1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "onnx.pb.h"
16
17 #include <algorithm>
18 #include <float.h>
19 #include <fstream>
20 #include <google/protobuf/io/coded_stream.h>
21 #include <google/protobuf/io/zero_copy_stream_impl.h>
22 #include <google/protobuf/message.h>
23 #include <google/protobuf/text_format.h>
24 #include <iostream>
25 #include <limits.h>
26 #include <limits>
27 #include <set>
28 #include <stdio.h>
29
read_proto_from_binary(const char * filepath,onnx::ModelProto * message)30 static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message)
31 {
32 std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
33 if (!fs.is_open())
34 {
35 fprintf(stderr, "open failed %s\n", filepath);
36 return false;
37 }
38
39 google::protobuf::io::IstreamInputStream input(&fs);
40 google::protobuf::io::CodedInputStream codedstr(&input);
41
42 #if GOOGLE_PROTOBUF_VERSION >= 3011000
43 codedstr.SetTotalBytesLimit(INT_MAX);
44 #else
45 codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
46 #endif
47
48 bool success = message->ParseFromCodedStream(&codedstr);
49
50 fs.close();
51
52 return success;
53 }
54
get_node_attr_ai(const onnx::NodeProto & node,const char * key)55 static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key)
56 {
57 std::vector<int> v;
58
59 for (int i = 0; i < node.attribute_size(); i++)
60 {
61 const onnx::AttributeProto& attr = node.attribute(i);
62 if (attr.name() == key)
63 {
64 v.resize(attr.ints_size());
65 for (int j = 0; j < attr.ints_size(); j++)
66 {
67 v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
68 }
69
70 break;
71 }
72 }
73
74 return v;
75 }
76
get_node_attr_af(const onnx::NodeProto & node,const char * key)77 static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key)
78 {
79 std::vector<float> v;
80
81 for (int i = 0; i < node.attribute_size(); i++)
82 {
83 const onnx::AttributeProto& attr = node.attribute(i);
84 if (attr.name() == key)
85 {
86 v.resize(attr.floats_size());
87 for (int j = 0; j < attr.floats_size(); j++)
88 {
89 v[j] = attr.floats(j);
90 }
91
92 break;
93 }
94 }
95
96 return v;
97 }
98
get_node_attr_i(const onnx::NodeProto & node,const char * key,int def=0)99 static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0)
100 {
101 for (int i = 0; i < node.attribute_size(); i++)
102 {
103 const onnx::AttributeProto& attr = node.attribute(i);
104 if (attr.name() == key)
105 {
106 return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
107 }
108 }
109
110 return def;
111 }
112
get_node_attr_f(const onnx::NodeProto & node,const char * key,float def=0.f)113 static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f)
114 {
115 for (int i = 0; i < node.attribute_size(); i++)
116 {
117 const onnx::AttributeProto& attr = node.attribute(i);
118 if (attr.name() == key)
119 {
120 return attr.f();
121 }
122 }
123
124 return def;
125 }
126
get_node_attr_s(const onnx::NodeProto & node,const char * key,const std::string & def=std::string ())127 static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, const std::string& def = std::string())
128 {
129 for (int i = 0; i < node.attribute_size(); i++)
130 {
131 const onnx::AttributeProto& attr = node.attribute(i);
132 if (attr.name() == key)
133 {
134 return attr.s();
135 }
136 }
137
138 return def;
139 }
140
// Fetch the tensor attribute named key from node.
// Returns a default-constructed (empty) TensorProto when absent.
static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key)
{
    for (const onnx::AttributeProto& attr : node.attribute())
    {
        if (attr.name() == key)
            return attr.t();
    }

    return onnx::TensorProto();
}
154
get_node_attr_from_input_f(const onnx::TensorProto & tp)155 static float get_node_attr_from_input_f(const onnx::TensorProto& tp)
156 {
157 float v = 0.f;
158
159 // float
160 if (tp.data_type() == 1)
161 {
162 const float* shape_data = 0;
163 if (tp.has_raw_data())
164 {
165 shape_data = (const float*)tp.raw_data().data();
166 }
167 else
168 {
169 shape_data = tp.float_data().data();
170 }
171 v = shape_data[0];
172 }
173 // double
174 else if (tp.data_type() == 11)
175 {
176 const double* shape_data = 0;
177 if (tp.has_raw_data())
178 {
179 shape_data = (const double*)tp.raw_data().data();
180 }
181 else
182 {
183 shape_data = tp.double_data().data();
184 }
185 v = shape_data[0];
186 }
187 // int64
188 else if (tp.data_type() == 7)
189 {
190 const int64_t* shape_data = 0;
191 if (tp.has_raw_data())
192 {
193 shape_data = (const int64_t*)tp.raw_data().data();
194 }
195 else
196 {
197 shape_data = tp.int64_data().data();
198 }
199 v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
200 }
201 // int32
202 else if (tp.data_type() == 6)
203 {
204 const int32_t* shape_data = 0;
205 if (tp.has_raw_data())
206 {
207 shape_data = (const int32_t*)tp.raw_data().data();
208 }
209 else
210 {
211 shape_data = tp.int32_data().data();
212 }
213 v = shape_data[0];
214 }
215 else
216 {
217 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
218 abort();
219 }
220
221 return v;
222 }
223
get_node_attr_from_input_ai(const onnx::TensorProto & tp)224 static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp)
225 {
226 int size = 0;
227
228 std::vector<int> v;
229
230 // int64
231 if (tp.data_type() == 7)
232 {
233 const int64_t* shape_data = 0;
234 if (tp.has_raw_data())
235 {
236 shape_data = (const int64_t*)tp.raw_data().data();
237 size = (int)(tp.raw_data().size() / 8);
238 }
239 else
240 {
241 shape_data = tp.int64_data().data();
242 size = tp.int64_data_size();
243 }
244 for (int j = 0; j < size; j++)
245 {
246 int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
247 v.push_back(vi);
248 }
249 }
250 // int32
251 else if (tp.data_type() == 6)
252 {
253 const int32_t* shape_data = 0;
254 if (tp.has_raw_data())
255 {
256 shape_data = (const int32_t*)tp.raw_data().data();
257 size = (int)(tp.raw_data().size() / 4);
258 }
259 else
260 {
261 shape_data = tp.int32_data().data();
262 size = tp.int32_data_size();
263 }
264 for (int j = 0; j < size; j++)
265 {
266 v.push_back(shape_data[j]);
267 }
268 }
269 else
270 {
271 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
272 }
273
274 return v;
275 }
276
get_node_attr_from_input_af(const onnx::TensorProto & tp)277 static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp)
278 {
279 int size = 0;
280
281 std::vector<float> v;
282
283 // float
284 if (tp.data_type() == 1)
285 {
286 const float* shape_data = 0;
287 if (tp.has_raw_data())
288 {
289 shape_data = (const float*)tp.raw_data().data();
290 size = (int)(tp.raw_data().size() / 4);
291 }
292 else
293 {
294 shape_data = tp.float_data().data();
295 size = tp.float_data_size();
296 }
297 for (int j = 0; j < size; j++)
298 {
299 v.push_back(shape_data[j]);
300 }
301 }
302 // double
303 else if (tp.data_type() == 11)
304 {
305 const double* shape_data = 0;
306 if (tp.has_raw_data())
307 {
308 shape_data = (const double*)tp.raw_data().data();
309 size = (int)(tp.raw_data().size() / 8);
310 }
311 else
312 {
313 shape_data = tp.double_data().data();
314 size = tp.double_data_size();
315 }
316 for (int j = 0; j < size; j++)
317 {
318 v.push_back((float)shape_data[j]);
319 }
320 }
321 else
322 {
323 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
324 }
325
326 return v;
327 }
328
get_tensor_proto_data_size(const onnx::TensorProto & tp)329 static int get_tensor_proto_data_size(const onnx::TensorProto& tp)
330 {
331 if (tp.has_raw_data())
332 {
333 const std::string& raw_data = tp.raw_data();
334 int size = (int)raw_data.size() / 4;
335 return size;
336 }
337 else if (tp.data_type() == 1)
338 {
339 return tp.float_data_size();
340 }
341
342 return 0;
343 }
344
fwrite_tensor_proto_data(const onnx::TensorProto & tp,FILE * bp)345 static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp)
346 {
347 int size = get_tensor_proto_data_size(tp);
348
349 if (tp.has_raw_data())
350 {
351 const std::string& raw_data = tp.raw_data();
352 fwrite(raw_data.data(), sizeof(float), size, bp);
353 }
354 else if (tp.data_type() == 1)
355 {
356 fwrite(tp.float_data().data(), sizeof(float), size, bp);
357 }
358 }
359
// Fold Reshape nodes whose input is a constant weight: the reshape is applied
// to the weight's dims directly and the node is marked "noop_reducedncnn" so a
// later sweep drops it. Updates node_reference counts for the consumed inputs
// and bumps reduced_node_count once per folded node. blob_names is unused here
// (kept for signature parity with the other fuse_* passes).
static void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Reshape(weight)
        if (node->op_type() == "Reshape")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // the reshape output becomes a weight aliasing the input data
            weights[node->output(0)] = weights[node->input(0)];

            // set weight shape directly
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                // pre-opset-5 style: shape carried as a node attribute
                shape = get_node_attr_ai(*node, "shape");
            }
            else if (node->input_size() == 2)
            {
                // opset 5
                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            weights[node->output(0)].clear_dims();
            for (int j = 0; j < shape.size(); j++)
            {
                weights[node->output(0)].add_dims(shape[j]);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }

            reduced_node_count += 1;
            // NOTE(review): also advances past the next node, matching the
            // skip pattern used by the other fuse_* passes — confirm intended
            i += 1;
        }
    }
}
408
// Fold Transpose nodes that permute a constant 2-D weight with perm=(1,0):
// the weight matrix is transposed in place and the node is marked
// "noop_reducedncnn" for later removal. Updates node_reference and
// reduced_node_count; blob_names is unused here (signature parity with the
// other fuse_* passes).
static void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Transpose(weight)
        if (node->op_type() == "Transpose")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // only plain 2-D matrices are handled
            if (weights[node->input(0)].dims_size() != 2)
                continue;

            // perm = (1, 0)
            std::vector<int> perm = get_node_attr_ai(*node, "perm");
            if (perm.size() != 2)
                continue;
            if (perm[0] != 1 || perm[1] != 0)
                continue;

            // the transpose output becomes a weight copied from the input
            weights[node->output(0)] = weights[node->input(0)];

            // permute weight
            {
                onnx::TensorProto& B = weights[node->output(0)];

                const int h = B.dims(0);
                const int w = B.dims(1);

                // gather columns of the h x w matrix to build the w x h result
                std::vector<float> permuted_data;
                permuted_data.reserve((size_t)h * w);
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                for (int j = 0; j < w; j++)
                {
                    for (int k = 0; k < h; k++)
                    {
                        float vb = bptr[k * w + j];
                        permuted_data.push_back(vb);
                    }
                }

                // swap the recorded dimensions to match the transposed layout
                B.set_dims(0, w);
                B.set_dims(1, h);

                // write the permuted values back through whichever storage
                // form the tensor already uses
                if (B.has_raw_data())
                {
                    B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float));
                }
                else
                {
                    for (int j = 0; j < (int)permuted_data.size(); j++)
                        B.set_float_data(j, permuted_data[j]);
                }
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;

            reduced_node_count += 1;
            // NOTE(review): also advances past the next node, matching the
            // skip pattern used by the other fuse_* passes — confirm intended
            i += 1;
        }
    }
}
479
// Recognize the Reshape - Transpose - Reshape idiom (optionally with an
// interleaved Constant node) that implements channel shuffle, and collapse it
// into a single "ShuffleChannel" node with "group" and "reverse" attributes.
// Both the standard 5-D form (1, groups, channels_per_group, h, w) and the
// reverse 3-D form (channels_per_group, groups, h*w) are matched. Consumed
// intermediate blobs are removed from blob_names and reference counts are
// decremented accordingly.
static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // ShuffleChannel <= Reshape - Transpose - Reshape
        // ShuffleChannel <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            // the reshape output must feed exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // 1 groups channels_per_group, height, width
            // reverse style = channels_per_group, groups, height * width
            if (shape.size() != 5 && shape.size() != 3)
                continue;

            if (shape.size() == 5 && shape[0] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant node emitted between Transpose and Reshape
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 2 1 3 4
            // reverse style = 1 0 2
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 5 && perm.size() != 3)
                continue;

            if (perm.size() == 5 && (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4))
                continue;

            if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2))
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // 1, -1, height, width
            // reverse style = group, -1, channels_per_group, height, width
            if (shape3.size() != 4 && shape3.size() != 5)
                continue;

            if (shape3.size() == 4 && (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2])))
                continue;

            if (shape3.size() == 5 && (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2]))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the trailing Reshape into the fused ShuffleChannel node
            node3->set_op_type("ShuffleChannel");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("group");
            attr_group->set_i(shape[1]);

            onnx::AttributeProto* attr_reverse = node3->add_attribute();
            attr_reverse->set_name("reverse");
            attr_reverse->set_i(shape.size() == 3);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
607
// After a reverse-style ShuffleChannel, two Gather nodes selecting index 0
// and index 1 along axis 0 implement a channel split; fuse them into a single
// "Split" node along axis 1. The first Gather is marked "noop_reducedncnn"
// and the second is rewritten into the Split with both outputs.
static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
        if (node->op_type() == "ShuffleChannel")
        {
            // reverse = 1
            int reverse = get_node_attr_i(*node, "reverse");
            if (reverse != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node2->op_type() != "Gather" || node3->op_type() != "Gather")
                continue;

            // both Gathers must consume the ShuffleChannel output
            if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0))
                continue;

            // axis = 0
            int gather2_axis = get_node_attr_i(*node2, "axis");
            if (gather2_axis != 0)
                continue;

            // indices = 0
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather2_indices.size() != 1 || gather2_indices[0] != 0)
                continue;

            // axis = 0
            int gather3_axis = get_node_attr_i(*node3, "axis");
            if (gather3_axis != 0)
                continue;

            // indices = 1
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
            if (gather3_indices.size() != 1 || gather3_indices[0] != 1)
                continue;

            // reduce
            node2->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 2;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(1)] -= 1;

            // rewrite node3 into Split: output 0 takes the first Gather's
            // blob name, output 1 keeps node3's original output name
            node3->set_op_type("Split");
            node3->clear_input();
            node3->add_input(node->output(0));
            node3->add_output(node3->output(0));
            node3->set_output(0, node2->output(0));

            node3->clear_attribute();
            onnx::AttributeProto* attr_axis = node3->add_attribute();
            attr_axis->set_name("axis");
            attr_axis->set_i(1);

            reduced_node_count += 1;
            i += 1;
        }
    }
}
684
// Collapse the expression x * relu6(x + 3) / 6 — expressed either as
// Add/Clip/Mul/Div (or Mul by 1/6, with an optional interleaved Constant),
// or as HardSigmoid followed by Mul — into a single "HardSwish" node with
// alpha=1/6 and beta=0.5. Consumed nodes become "noop_reducedncnn" and
// intermediate blobs are dropped from blob_names.
static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6))
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6))
        // out = x * F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            // the addend must be a constant scalar equal to 3
            if (weights.find(node->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // tolerate a Constant node before the final Div/Mul
            if (node4->op_type() == "Constant")
            {
                if (i + 4 >= node_count)
                    continue;

                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Mul" || (node4->op_type() != "Div" && node4->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Clip must be relu6: min=0, max=6 (attribute form pre
            // opset 11, constant-input form afterwards)
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // the Mul must multiply the original x by the clipped branch
            if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0))
                continue;

            // the final scale must be a constant scalar: 6 for Div, 1/6 for Mul
            if (weights.find(node4->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node4->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node4->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->output(0)] -= 1;
            node_reference[node4->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));

            // rewrite the trailing Div/Mul into the fused HardSwish
            node4->set_op_type("HardSwish");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node4->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 3;
            i += 3;
        }
    }

    // second pattern: HardSigmoid already produced by fuse_hardsigmoid
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= HardSigmoid - Mul
        // out = x * hsigmoid(x)
        if (node->op_type() == "HardSigmoid")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // carry the HardSigmoid coefficients over to the HardSwish
            float alpha = get_node_attr_f(*node, "alpha", 0.2f);
            float beta = get_node_attr_f(*node, "beta", 0.5f);

            if (i + 1 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "Mul")
                continue;

            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            node2->set_op_type("HardSwish");
            node2->clear_input();
            node2->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node2->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(alpha);

            onnx::AttributeProto* attr_beta = node2->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(beta);

            reduced_node_count += 1;
            i += 1;
        }
    }
}
859
// Collapse relu6(x + 3) / 6 — expressed as Add/Clip/Div (or Mul by 1/6,
// with an optional interleaved Constant) — into a single "HardSigmoid" node
// with alpha=1/6 and beta=0.5. Consumed nodes become "noop_reducedncnn" and
// intermediate blobs are dropped from blob_names.
static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6))
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6))
        // out = F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            // the addend must be a constant scalar equal to 3
            if (weights.find(node->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant node before the final Div/Mul
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Clip must be relu6: min=0, max=6 (attribute form pre
            // opset 11, constant-input form afterwards)
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            // the final scale must be a constant scalar: 6 for Div, 1/6 for Mul
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node3->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node3->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the trailing Div/Mul into the fused HardSigmoid
            node3->set_op_type("HardSigmoid");
            node3->clear_input();
            node3->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node3->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node3->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
973
// Collapse x * sigmoid(x) — a Sigmoid node whose sole consumer is a Mul that
// also takes the original x — into a single "Swish" node. The Sigmoid is
// marked "noop_reducedncnn" and its intermediate blob is dropped. weights is
// unused here (signature parity with the other fuse_* passes).
static void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Swish <= Sigmoid - Mul
        // x * torch.sigmoid(x)
        if (node->op_type() == "Sigmoid")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 1 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "Mul")
                continue;

            // Mul must combine the original x with the sigmoid output
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            // rewrite the Mul into the fused Swish taking x directly
            node2->set_op_type("Swish");
            node2->clear_input();
            node2->add_input(node->input(0));

            reduced_node_count += 1;
            i += 1;
        }
    }
}
1016
// Remove the Unsqueeze/Squeeze pair that PyTorch emits around a 1-D
// BatchNormalization: the BatchNormalization node is rewired to consume the
// original input and produce the Squeeze's output, and the two wrapper nodes
// become "noop_reducedncnn". weights is unused here (signature parity with
// the other fuse_* passes).
static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze
        if (node->op_type() == "Unsqueeze")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the three nodes must form a straight chain
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // splice BatchNormalization directly between the outer blobs
            node2->set_input(0, node->input(0));
            node2->set_output(0, node3->output(0));

            reduced_node_count += 2;
            i += 2;
        }
    }
}
1063
// Remove the Unsqueeze that expands a 1-D PReLU slope weight to (C, 1, 1):
// the following PRelu node is rewired to take the original 1-D weight and the
// Unsqueeze becomes "noop_reducedncnn".
static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // PReLU <= Unsqueeze - PReLU
        if (node->op_type() == "Unsqueeze")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // only a 1-D slope tensor qualifies
            onnx::TensorProto& B = weights[node->input(0)];
            if (B.dims_size() != 1)
                continue;

            if (node_reference[node->output(0)] != 1)
                continue;

            // axes = (1, 2)
            std::vector<int> axes = get_node_attr_ai(*node, "axes");
            if (axes.size() != 2)
                continue;
            if (axes[0] != 1 || axes[1] != 2)
                continue;

            if (i + 1 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "PRelu")
                continue;

            // the unsqueezed weight must feed the PRelu slope input
            if (node2->input(1) != node->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            // point the PRelu slope at the original 1-D weight
            node2->set_input(1, node->input(0));

            reduced_node_count += 1;
            i += 1;
        }
    }
}
1117
// Collapse the ReduceL2 - Clip - (Shape) - Expand - Div subgraph (the export
// of an L2 normalization such as F.normalize) into a single "Normalize" node.
// The Clip lower bound becomes the Normalize "eps" attribute.  Matched helper
// nodes are neutralized by renaming them "noop_reducedncnn"; reference counts
// and intermediate blob names are updated accordingly.
static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Normalize <= X - ReduceL2 - Clip - Expand - Div
        // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div
        if (node->op_type() == "ReduceL2")
        {
            // the norm blob must feed exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            // axes = (1) -- norm is taken along the channel axis only
            std::vector<int> axes = get_node_attr_ai(*node, "axes");
            if (axes.size() != 1)
                continue;
            if (axes[0] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // optional Shape node between Clip and Expand -- shift the
            // Expand/Div candidates one slot further when present
            bool has_shape_node = node3->op_type() == "Shape";
            onnx::NodeProto* node_shape = 0;
            if (has_shape_node)
            {
                if (i + 4 >= node_count)
                    continue;

                node_shape = node3;
                node3 = mutable_graph->mutable_node(i + 3);
                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // the chain must be ReduceL2 -> Clip -> Expand -> Div, with Div
            // also consuming the original input X directly
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
                    || node4->input(0) != node->input(0) || node4->input(1) != node3->output(0))
                continue;

            if (has_shape_node)
            {
                // Shape must read X and feed the Expand target-shape input
                if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0))
                    continue;
            }

            // +eps -- Clip's lower bound is either an attribute (opset < 11)
            // or a constant input tensor (opset >= 11)
            float clip_min;
            if (node2->input_size() == 1)
            {
                clip_min = get_node_attr_f(*node2, "min", -FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];

                clip_min = get_node_attr_from_input_f(min_tp);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            if (has_shape_node)
            {
                node_shape->set_op_type("noop_reducedncnn");
            }
            node3->set_op_type("noop_reducedncnn");

            // X loses its ReduceL2 consumer (and the Shape one, if present);
            // the Div consumer survives as the fused Normalize input
            node_reference[node->input(0)] -= has_shape_node ? 2 : 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (has_shape_node)
            {
                node_reference[node_shape->output(0)] -= 1;
            }
            node_reference[node3->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            if (has_shape_node)
            {
                blob_names.erase(node_shape->output(0));
            }
            blob_names.erase(node3->output(0));

            // rewrite the Div into the fused Normalize node
            node4->set_op_type("Normalize");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("eps");
            attr_alpha->set_f(clip_min);

            reduced_node_count += has_shape_node ? 4 : 3;
            i += has_shape_node ? 4 : 3;
        }
    }
}
1233
fuse_groupnorm(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1234 static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1235 {
1236 int node_count = mutable_graph->node_size();
1237 for (int i = 0; i < node_count; i++)
1238 {
1239 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1240
1241 // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add
1242 if (node->op_type() == "Reshape")
1243 {
1244 if (node_reference[node->output(0)] != 1)
1245 continue;
1246
1247 std::vector<int> shape;
1248 if (node->input_size() == 1)
1249 {
1250 shape = get_node_attr_ai(*node, "shape");
1251 }
1252 else
1253 {
1254 // skip weight reshape
1255 if (weights.find(node->input(1)) == weights.end())
1256 continue;
1257
1258 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1259 }
1260
1261 // 0, group, -1
1262 if (shape.size() != 3)
1263 continue;
1264
1265 if (shape[0] != 0 || shape[2] != -1)
1266 continue;
1267
1268 int groups = shape[1];
1269
1270 if (i + 4 >= node_count)
1271 continue;
1272
1273 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1274 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1275 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
1276 onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
1277
1278 if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || node4->op_type() != "Mul" || node5->op_type() != "Add")
1279 continue;
1280
1281 if (node_reference[node2->output(0)] != 1)
1282 continue;
1283
1284 if (node_reference[node3->output(0)] != 1)
1285 continue;
1286
1287 if (node_reference[node4->output(0)] != 1)
1288 continue;
1289
1290 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
1291 || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0))
1292 continue;
1293
1294 // +eps
1295 float eps = get_node_attr_f(*node2, "epsilon", 1e-05f);
1296
1297 // InstanceNormalization S=1 B=0
1298 std::vector<float> S = get_node_attr_from_input_af(weights[node2->input(1)]);
1299 std::vector<float> B = get_node_attr_from_input_af(weights[node2->input(2)]);
1300 if ((int)S.size() != groups || (int)B.size() != groups)
1301 continue;
1302
1303 bool instancenorm_affine = false;
1304 for (int j = 0; j < groups; j++)
1305 {
1306 if (S[j] != 1.f || B[j] != 0.f)
1307 {
1308 instancenorm_affine = true;
1309 break;
1310 }
1311 }
1312
1313 if (instancenorm_affine)
1314 continue;
1315
1316 std::vector<int> shape2;
1317 if (node3->input_size() == 1)
1318 {
1319 shape2 = get_node_attr_ai(*node3, "shape");
1320 }
1321 else
1322 {
1323 // skip weight reshape
1324 if (weights.find(node3->input(1)) == weights.end())
1325 continue;
1326
1327 shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1328 }
1329
1330 // 1, channels, w, h
1331 if (shape2.size() != 4)
1332 continue;
1333
1334 if (shape2[0] != 1)
1335 continue;
1336
1337 int channels = shape2[1];
1338
1339 // affine
1340 std::vector<float> affine_S = get_node_attr_from_input_af(weights[node4->input(1)]);
1341 std::vector<float> affine_B = get_node_attr_from_input_af(weights[node5->input(1)]);
1342 if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
1343 {
1344 // no affine
1345 }
1346 else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels)
1347 {
1348 // we only allow per-channel affine
1349 continue;
1350 }
1351
1352 // reduce
1353 node->set_op_type("noop_reducedncnn");
1354 node2->set_op_type("noop_reducedncnn");
1355 node3->set_op_type("noop_reducedncnn");
1356 node4->set_op_type("noop_reducedncnn");
1357
1358 if (node->input_size() == 2)
1359 {
1360 node_reference[node->input(1)] -= 1;
1361 }
1362 node_reference[node->output(0)] -= 1;
1363 node_reference[node2->input(1)] -= 1;
1364 node_reference[node2->input(2)] -= 1;
1365 node_reference[node2->output(0)] -= 1;
1366 if (node3->input_size() == 2)
1367 {
1368 node_reference[node3->input(1)] -= 1;
1369 }
1370 node_reference[node3->output(0)] -= 1;
1371 node_reference[node4->output(0)] -= 1;
1372
1373 blob_names.erase(node->output(0));
1374 blob_names.erase(node2->output(0));
1375 blob_names.erase(node3->output(0));
1376 blob_names.erase(node4->output(0));
1377
1378 std::string affine_scale = node4->input(1);
1379 std::string affine_bias = node5->input(1);
1380
1381 node5->set_op_type("GroupNorm");
1382 node5->clear_input();
1383 node5->add_input(node->input(0));
1384 node5->add_input(affine_scale);
1385 node5->add_input(affine_bias);
1386
1387 onnx::AttributeProto* attr_groups = node5->add_attribute();
1388 attr_groups->set_name("groups");
1389 attr_groups->set_i(groups);
1390
1391 onnx::AttributeProto* attr_channels = node5->add_attribute();
1392 attr_channels->set_name("channels");
1393 attr_channels->set_i(channels);
1394
1395 onnx::AttributeProto* attr_eps = node5->add_attribute();
1396 attr_eps->set_name("epsilon");
1397 attr_eps->set_f(eps);
1398
1399 onnx::AttributeProto* attr_affine = node5->add_attribute();
1400 attr_affine->set_name("affine");
1401 attr_affine->set_i(1);
1402
1403 reduced_node_count += 4;
1404 i += 4;
1405 }
1406 }
1407 }
1408
// Collapse the ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div
// subgraph (optionally followed by Mul - Add for the affine part) into a
// single "LayerNorm" node.  The Add constant becomes the epsilon attribute.
// Matched helper nodes are neutralized by renaming them "noop_reducedncnn";
// reference counts and intermediate blob names are updated accordingly.
static void fuse_layernorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div
        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - Mul - Add
        if (node->op_type() == "ReduceMean")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> axes = get_node_attr_ai(*node, "axes");

            // -1
            // -2 -1
            // normalization over the last one or two axes only
            if (axes.size() != 1 && axes.size() != 2)
                continue;

            int normed_axes = (int)axes.size();
            if (normed_axes == 1 && axes[0] != -1)
                continue;
            if (normed_axes == 2 && (axes[0] != -2 || axes[1] != -1))
                continue;

            if (i + 6 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Sub" || node3->op_type() != "Pow" || node4->op_type() != "ReduceMean" || node5->op_type() != "Add" || node6->op_type() != "Sqrt" || node7->op_type() != "Div")
                continue;

            // the centered blob (X - mean) is consumed twice: by Pow and by Div
            if (node_reference[node2->output(0)] != 2)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            if (node_reference[node5->output(0)] != 1)
                continue;

            if (node_reference[node6->output(0)] != 1)
                continue;

            // exact dataflow of the layernorm formula:
            // (X - mean(X)) / sqrt(mean((X - mean(X))^2) + eps)
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)
                    || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0)
                    || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0)
                    || node7->input(0) != node2->output(0) || node7->input(1) != node6->output(0))
                continue;

            if (weights.find(node3->input(1)) == weights.end())
                continue;

            // Pow exponent must be the scalar constant 2
            const onnx::TensorProto& pow_two = weights[node3->input(1)];
            if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1)
                continue;

            float constant_pow_two = get_node_attr_from_input_f(pow_two);
            if (constant_pow_two != 2.f)
                continue;

            std::vector<int> axes4 = get_node_attr_ai(*node4, "axes");

            // -1
            // -2 -1
            // the variance ReduceMean must use the same axes as the mean
            if ((int)axes4.size() != normed_axes)
                continue;

            if (normed_axes == 1 && axes4[0] != -1)
                continue;
            if (normed_axes == 2 && (axes4[0] != -2 || axes4[1] != -1))
                continue;

            if (weights.find(node5->input(1)) == weights.end())
                continue;

            // Add operand must be a scalar constant -- the epsilon
            const onnx::TensorProto& add_eps = weights[node5->input(1)];
            if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1)
                continue;

            float eps = get_node_attr_from_input_f(add_eps);

            // look ahead for an optional trailing Mul - Add affine pair
            int affine = 0;
            while (i + 8 < node_count)
            {
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                if (node8->op_type() != "Mul" || node9->op_type() != "Add")
                    break;

                if (node_reference[node7->output(0)] != 1)
                    break;

                if (node_reference[node8->output(0)] != 1)
                    break;

                if (node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0))
                    break;

                // affine
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node8->input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node9->input(1)]);
                if (affine_S.size() != affine_B.size())
                    break;

                affine = 1;
                break;
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node2->input(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(0)] -= 1;
            node_reference[node3->input(1)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->input(1)] -= 1;
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            node_reference[node7->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            // X is consumed again by the fused LayerNorm node, restoring the
            // two references (ReduceMean + Sub) removed above down to one
            node_reference[node->input(0)] += 1;

            if (affine == 0)
            {
                // rewrite the Div into the fused (non-affine) LayerNorm
                node7->set_op_type("LayerNorm");
                node7->clear_input();
                node7->add_input(node->input(0));

                onnx::AttributeProto* attr_eps = node7->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node7->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 6;
                i += 6;
            }
            else // if (affine == 1)
            {
                // also absorb the Mul/Add pair; their constants become the
                // LayerNorm gamma/beta inputs
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                node7->set_op_type("noop_reducedncnn");
                node8->set_op_type("noop_reducedncnn");

                node_reference[node8->input(0)] -= 1;
                node_reference[node9->input(0)] -= 1;

                blob_names.erase(node7->output(0));
                blob_names.erase(node8->output(0));

                std::string affine_scale = node8->input(1);
                std::string affine_bias = node9->input(1);

                node9->set_op_type("LayerNorm");
                node9->clear_input();
                node9->add_input(node->input(0));
                node9->add_input(affine_scale);
                node9->add_input(affine_bias);

                onnx::AttributeProto* attr_eps = node9->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node9->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 8;
                i += 8;
            }
        }
    }
}
1613
// Collapse the Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat -
// Reshape subgraph (the export of x.view(x.size(0), -1)) into a single
// "Flatten" node.  The Constant node is deliberately left in place because
// other nodes may still reference it.  Matched helper nodes are neutralized
// by renaming them "noop_reducedncnn"; reference counts and intermediate
// blob names are updated accordingly.
static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape
        if (node->op_type() == "Shape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 6 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze"
                    || node6->op_type() != "Concat" || node7->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Constant may have extra consumers, so its refcount is not
            // required to be 1 (and it is not removed below)
            // if (node_reference[node3->output(0)] != 1)
            // continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            if (node_reference[node5->output(0)] != 1)
                continue;

            if (node_reference[node6->output(0)] != 1)
                continue;

            // dataflow: Concat(Unsqueeze(Gather(Shape(X), 0)), Unsqueeze(-1))
            // must feed the Reshape target-shape input, with X as its data
            if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0)
                    || node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0)
                    || node7->input(0) != node->input(0) || node7->input(1) != node6->output(0))
                continue;

            // axis = 0
            int gather_axis = get_node_attr_i(*node2, "axis");
            if (gather_axis != 0)
                continue;

            // indices = 0 -- Gather must pick the batch dimension
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather_indices.size() != 1 || gather_indices[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes");
            if (unsqueeze_axes.size() != 1)
                continue;
            if (unsqueeze_axes[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes");
            if (unsqueeze2_axes.size() != 1)
                continue;
            if (unsqueeze2_axes[0] != 0)
                continue;

            // data = -1 -- the second reshape dim must be the flatten -1
            if (weights.find(node5->input(0)) == weights.end())
                continue;

            std::vector<int> unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]);
            if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1)
                continue;

            // axis = 0
            int concat_axis = get_node_attr_i(*node6, "axis");
            if (concat_axis != 0)
                continue;

            // reduce (the Constant node3 is intentionally kept)
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            // node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node2->output(0)] -= 1;
            // node_reference[node3->output(0)] -= 1;
            node_reference[node4->output(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->output(0)] -= 1;
            node_reference[node6->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            // blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            // rewrite the Reshape into the fused Flatten node
            node7->set_op_type("Flatten");
            node7->clear_input();
            node7->add_input(node->input(0));

            reduced_node_count += 5;
            i += 5;
        }
    }
}
1735
// Collapse the Reshape - Transpose - Reshape subgraph (the export of
// nn.PixelShuffle / depth-to-space) into a single "PixelShuffle" node.
// An interleaved Constant node (holding the second reshape target) is
// tolerated and skipped.  Matched helper nodes are neutralized by renaming
// them "noop_reducedncnn"; reference counts and intermediate blob names are
// updated accordingly.
static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // PixelShuffle <= Reshape - Transpose - Reshape
        // PixelShuffle <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // the reshape target comes from an attribute (opset < 5) or a
            // constant second input
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, 3, upscale_factor, upscale_factor, height, width
            if (shape.size() != 6)
                continue;

            if (shape[0] != 1 && shape[0] != -1)
                continue;

            // both upscale factors must match
            if (shape[2] != shape[3])
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant between Transpose and the final Reshape
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 1 4 2 5 3 -- the pixel-shuffle permutation
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6)
                continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 || perm[5] != 3)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, 3, height, width
            if (shape3.size() != 4)
                continue;

            if (shape3[0] != 1 && shape3[0] != -1)
                continue;

            // the output must be (C, H*factor, W*factor)
            if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] || shape3[3] != shape[3] * shape[5])
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the final Reshape into the fused PixelShuffle node
            node3->set_op_type("PixelShuffle");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("scale_factor");
            attr_group->set_i(shape[2]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
1856
// Collapse the Reshape - Transpose - Reshape subgraph of a space-to-depth
// operation into a single "Reorg" node (the inverse permutation of
// fuse_pixelshuffle).  An interleaved Constant node is tolerated and skipped.
// Matched helper nodes are neutralized by renaming them "noop_reducedncnn";
// reference counts and intermediate blob names are updated accordingly.
static void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Reorg <= Reshape - Transpose - Reshape
        // Reorg <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // the reshape target comes from an attribute (opset < 5) or a
            // constant second input
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, 3, out_height, block_size, out_width, block_size
            if (shape.size() != 6)
                continue;

            if (shape[0] != 1 && shape[0] != -1)
                continue;

            // both block sizes must match
            if (shape[3] != shape[5])
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant between Transpose and the final Reshape
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 1 3 5 2 4 -- the space-to-depth permutation
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6)
                continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 || perm[5] != 4)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, out_channels, out_height, out_width
            if (shape3.size() != 4)
                continue;

            if (shape3[0] != 1 && shape3[0] != -1)
                continue;

            // the output must be (C*block*block, out_height, out_width)
            if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] || shape3[3] != shape[4])
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the final Reshape into the fused Reorg node
            node3->set_op_type("Reorg");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("stride");
            attr_group->set_i(shape[3]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
1977
fuse_expand_broadcast(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1978 static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1979 {
1980 int node_count = mutable_graph->node_size();
1981 for (int i = 0; i < node_count; i++)
1982 {
1983 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1984
1985 // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max
1986 if (node->op_type() == "Expand")
1987 {
1988 if (node_reference[node->output(0)] != 1)
1989 continue;
1990
1991 if (i + 1 >= node_count)
1992 continue;
1993
1994 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1995
1996 if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" && node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max")
1997 continue;
1998
1999 if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0))
2000 continue;
2001
2002 // reduce
2003 node->set_op_type("noop_reducedncnn");
2004
2005 node_reference[node->output(0)] -= 1;
2006 if (node->input_size() == 2)
2007 {
2008 node_reference[node->input(1)] -= 1;
2009 }
2010
2011 blob_names.erase(node->output(0));
2012
2013 node2->set_input(1, node->input(0));
2014
2015 reduced_node_count += 1;
2016 i += 1;
2017 }
2018 }
2019 }
2020
fuse_lstm_gru_rnn(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2021 static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2022 {
2023 int node_count = mutable_graph->node_size();
2024 for (int i = 0; i < node_count; i++)
2025 {
2026 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2027
2028 // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
2029 if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2030 {
2031 if (node_reference[node->output(0)] != 1)
2032 continue;
2033
2034 if (i + 2 >= node_count)
2035 continue;
2036
2037 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2038 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2039
2040 if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
2041 continue;
2042
2043 if (node_reference[node2->output(0)] != 1)
2044 continue;
2045
2046 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
2047 continue;
2048
2049 std::string direction = get_node_attr_s(*node, "direction");
2050 if (direction != "bidirectional")
2051 continue;
2052
2053 // 0 2 1 3
2054 std::vector<int> perm = get_node_attr_ai(*node2, "perm");
2055 if (perm.size() != 4)
2056 continue;
2057
2058 if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3)
2059 continue;
2060
2061 std::vector<int> shape;
2062 if (node3->input_size() == 1)
2063 {
2064 shape = get_node_attr_ai(*node3, "shape");
2065 }
2066 else
2067 {
2068 // skip weight reshape
2069 if (weights.find(node3->input(1)) == weights.end())
2070 continue;
2071
2072 shape = get_node_attr_from_input_ai(weights[node3->input(1)]);
2073 }
2074
2075 // 0 0 -1
2076 if (shape.size() != 3)
2077 continue;
2078
2079 if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1)
2080 continue;
2081
2082 // reduce
2083 node2->set_op_type("noop_reducedncnn");
2084 node3->set_op_type("noop_reducedncnn");
2085
2086 node_reference[node->output(0)] -= 1;
2087 node_reference[node2->output(0)] -= 1;
2088 if (node3->input_size() == 2)
2089 {
2090 node_reference[node3->input(1)] -= 1;
2091 }
2092
2093 blob_names.erase(node->output(0));
2094 blob_names.erase(node2->output(0));
2095
2096 node->set_output(0, node3->output(0));
2097
2098 reduced_node_count += 2;
2099 i += 2;
2100
2101 if (i + 1 < node_count)
2102 {
2103 if (node_reference[node3->output(0)] != 1)
2104 continue;
2105
2106 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1);
2107
2108 if (node4->op_type() != "Transpose")
2109 continue;
2110
2111 if (node4->input(0) != node->output(0))
2112 continue;
2113
2114 // 1 0 2
2115 std::vector<int> perm4 = get_node_attr_ai(*node4, "perm");
2116 if (perm4.size() != 3)
2117 continue;
2118
2119 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2120 continue;
2121
2122 // reduce
2123 node4->set_op_type("noop_reducedncnn");
2124
2125 node_reference[node->output(0)] -= 1;
2126
2127 blob_names.erase(node->output(0));
2128
2129 node->clear_output();
2130 node->add_output(node4->output(0));
2131
2132 reduced_node_count += 1;
2133 i += 1;
2134 }
2135 }
2136 }
2137
2138 for (int i = 0; i < node_count; i++)
2139 {
2140 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2141
2142 // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose
2143 if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2144 {
2145 if (node_reference[node->output(0)] != 1)
2146 continue;
2147
2148 if (i + 1 >= node_count)
2149 continue;
2150
2151 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2152
2153 if (node2->op_type() != "Squeeze")
2154 continue;
2155
2156 if (node2->input(0) != node->output(0))
2157 continue;
2158
2159 std::string direction = get_node_attr_s(*node, "direction");
2160 if (direction == "bidirectional")
2161 continue;
2162
2163 // 1
2164 std::vector<int> axes = get_node_attr_ai(*node2, "axes");
2165 if (axes.size() != 1)
2166 continue;
2167
2168 if (axes[0] != 1)
2169 continue;
2170
2171 // reduce
2172 node2->set_op_type("noop_reducedncnn");
2173
2174 node_reference[node->output(0)] -= 1;
2175
2176 blob_names.erase(node->output(0));
2177
2178 node->set_output(0, node2->output(0));
2179
2180 reduced_node_count += 1;
2181 i += 1;
2182
2183 if (i + 1 < node_count)
2184 {
2185 if (node_reference[node2->output(0)] != 1)
2186 continue;
2187
2188 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1);
2189
2190 if (node3->op_type() != "Transpose")
2191 continue;
2192
2193 if (node3->input(0) != node->output(0))
2194 continue;
2195
2196 // 1 0 2
2197 std::vector<int> perm4 = get_node_attr_ai(*node3, "perm");
2198 if (perm4.size() != 3)
2199 continue;
2200
2201 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2202 continue;
2203
2204 // reduce
2205 node3->set_op_type("noop_reducedncnn");
2206
2207 node_reference[node->output(0)] -= 1;
2208
2209 blob_names.erase(node->output(0));
2210
2211 node->clear_output();
2212 node->add_output(node3->output(0));
2213
2214 reduced_node_count += 1;
2215 i += 1;
2216 }
2217 }
2218 }
2219
2220 for (int i = 0; i < node_count; i++)
2221 {
2222 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2223
2224 // LSTM <= Transpose - LSTM
2225 if (node->op_type() == "Transpose")
2226 {
2227 if (node_reference[node->output(0)] != 1)
2228 continue;
2229
2230 // 1 0 2
2231 std::vector<int> perm = get_node_attr_ai(*node, "perm");
2232 if (perm.size() != 3)
2233 continue;
2234
2235 if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)
2236 continue;
2237
2238 if (i + 1 >= node_count)
2239 continue;
2240
2241 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2242
2243 if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN")
2244 continue;
2245
2246 if (node2->input(0) != node->output(0))
2247 continue;
2248
2249 // reduce
2250 node->set_op_type("noop_reducedncnn");
2251
2252 node_reference[node->output(0)] -= 1;
2253
2254 blob_names.erase(node->output(0));
2255
2256 node2->set_input(0, node->input(0));
2257
2258 reduced_node_count += 1;
2259 i += 1;
2260 }
2261 }
2262 }
2263
// Fuse an exported multi-head attention subgraph into a single
// "MultiHeadAttention" node. Two fixed node sequences are recognized
// (nodes must be consecutive in the graph and each intermediate blob must
// have exactly one consumer):
//   1. separate q/k/v projections (20 nodes, MatMul+Add per projection)
//   2. a packed qkv projection followed by Split (17 nodes)
// All matched helper nodes are retyped to "noop_reducedncnn"; the final Add
// node is rewritten in place into the fused op, and node_reference /
// blob_names bookkeeping is updated accordingly.
static void fuse_multiheadattention(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // MultiHeadAttention <= MatMul(q) - Add
        //                       - MatMul(k) - Add
        //                       - MatMul(v) - Add
        //                       - Mul
        //                       - Reshape - Transpose
        //                       - Reshape - Reshape - Transpose - Transpose
        //                       - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
        if (node->op_type() == "MatMul")
        {
            if (i + 19 >= node_count)
                continue;

            if (node_reference[node->output(0)] != 1)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
            onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
            onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
            onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
            onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
            onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
            onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
            onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
            onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
            onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
            onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
            onnx::NodeProto* node18 = mutable_graph->mutable_node(i + 17);
            onnx::NodeProto* node19 = mutable_graph->mutable_node(i + 18);
            onnx::NodeProto* node20 = mutable_graph->mutable_node(i + 19);

            // exact op-type sequence check
            if (node2->op_type() != "Add" || node3->op_type() != "MatMul" || node4->op_type() != "Add" || node5->op_type() != "MatMul" || node6->op_type() != "Add" || node7->op_type() != "Mul" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Reshape" || node11->op_type() != "Reshape" || node12->op_type() != "Transpose" || node13->op_type() != "Transpose" || node14->op_type() != "MatMul" || node15->op_type() != "Softmax" || node16->op_type() != "MatMul" || node17->op_type() != "Transpose" || node18->op_type() != "Reshape" || node19->op_type() != "MatMul" || node20->op_type() != "Add")
                continue;

            // every intermediate blob must have exactly one consumer
            if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1 || node_reference[node17->output(0)] != 1 || node_reference[node18->output(0)] != 1 || node_reference[node19->output(0)] != 1)
                continue;

            // the nodes must be wired together in the expected dataflow shape
            if (node2->input(0) != node->output(0) || node4->input(0) != node3->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node2->output(0) || node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0) || node10->input(0) != node4->output(0) || node11->input(0) != node6->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node10->output(0) || node14->input(0) != node9->output(0) || node14->input(1) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node16->input(1) != node12->output(0) || node17->input(0) != node16->output(0) || node18->input(0) != node17->output(0) || node19->input(0) != node18->output(0) || node20->input(0) != node19->output(0))
                continue;

            // projection bias vectors; their common length is the embedding dim
            std::vector<float> q_B = get_node_attr_from_input_af(weights[node2->input(1)]);
            std::vector<float> k_B = get_node_attr_from_input_af(weights[node4->input(1)]);
            std::vector<float> v_B = get_node_attr_from_input_af(weights[node6->input(1)]);
            std::vector<float> o_B = get_node_attr_from_input_af(weights[node20->input(1)]);

            if (q_B.size() != k_B.size() || q_B.size() != v_B.size() || q_B.size() != o_B.size())
                continue;

            int embed_dim = q_B.size();

            // 1 0 2
            std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
            std::vector<int> perm12 = get_node_attr_ai(*node12, "perm");
            if (perm9.size() != 3 || perm12.size() != 3)
                continue;

            if (perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2 || perm12[0] != 1 || perm12[1] != 0 || perm12[2] != 2)
                continue;

            // 1 2 0
            std::vector<int> perm13 = get_node_attr_ai(*node13, "perm");
            if (perm13.size() != 3)
                continue;

            if (perm13[0] != 1 || perm13[1] != 2 || perm13[2] != 0)
                continue;

            // 1 0 2
            std::vector<int> perm17 = get_node_attr_ai(*node17, "perm");
            if (perm17.size() != 3)
                continue;

            if (perm17[0] != 1 || perm17[1] != 0 || perm17[2] != 2)
                continue;

            int softmax_axis = get_node_attr_i(*node15, "axis");
            if (softmax_axis != 2)
                continue;

            // 1/-1, seqlen * num_heads, embed_dim / num_heads
            std::vector<int> shape8;
            std::vector<int> shape10;
            std::vector<int> shape11;
            if (node8->input_size() == 1)
            {
                shape8 = get_node_attr_ai(*node8, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node8->input(1)) == weights.end())
                    continue;

                shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
            }
            if (node10->input_size() == 1)
            {
                shape10 = get_node_attr_ai(*node10, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node10->input(1)) == weights.end())
                    continue;

                shape10 = get_node_attr_from_input_ai(weights[node10->input(1)]);
            }
            if (node11->input_size() == 1)
            {
                shape11 = get_node_attr_ai(*node11, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node11->input(1)) == weights.end())
                    continue;

                shape11 = get_node_attr_from_input_ai(weights[node11->input(1)]);
            }

            if (shape8.size() != 3 || shape10.size() != 3 || shape11.size() != 3)
                continue;

            if (shape8[1] != shape10[1] || shape8[1] != shape11[1] || shape8[2] != shape10[2] || shape8[2] != shape11[2])
                continue;

            // shape8[2] is the per-head dimension
            int num_heads = embed_dim / shape8[2];

            // 1, seqlen, embed_dim
            std::vector<int> shape18;
            if (node18->input_size() == 1)
            {
                shape18 = get_node_attr_ai(*node18, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node18->input(1)) == weights.end())
                    continue;

                shape18 = get_node_attr_from_input_ai(weights[node18->input(1)]);
            }

            if (shape18.size() != 3)
                continue;

            if (shape18[2] != embed_dim || shape18[1] * num_heads != shape8[1])
                continue;

            // reduce: retype all but the final Add, which becomes the fused node
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");
            node7->set_op_type("noop_reducedncnn");
            node8->set_op_type("noop_reducedncnn");
            node9->set_op_type("noop_reducedncnn");
            node10->set_op_type("noop_reducedncnn");
            node11->set_op_type("noop_reducedncnn");
            node12->set_op_type("noop_reducedncnn");
            node13->set_op_type("noop_reducedncnn");
            node14->set_op_type("noop_reducedncnn");
            node15->set_op_type("noop_reducedncnn");
            node16->set_op_type("noop_reducedncnn");
            node17->set_op_type("noop_reducedncnn");
            node18->set_op_type("noop_reducedncnn");
            node19->set_op_type("noop_reducedncnn");

            // drop one reference for every consumed intermediate/constant input
            node_reference[node2->input(0)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            node_reference[node7->input(1)] -= 1;
            node_reference[node8->input(0)] -= 1;
            if (node8->input_size() == 2)
            {
                node_reference[node8->input(1)] -= 1;
            }
            node_reference[node9->input(0)] -= 1;
            node_reference[node10->input(0)] -= 1;
            if (node10->input_size() == 2)
            {
                node_reference[node10->input(1)] -= 1;
            }
            node_reference[node11->input(0)] -= 1;
            if (node11->input_size() == 2)
            {
                node_reference[node11->input(1)] -= 1;
            }
            node_reference[node12->input(0)] -= 1;
            node_reference[node13->input(0)] -= 1;
            node_reference[node14->input(0)] -= 1;
            node_reference[node14->input(1)] -= 1;
            node_reference[node15->input(0)] -= 1;
            node_reference[node16->input(0)] -= 1;
            node_reference[node16->input(1)] -= 1;
            node_reference[node17->input(0)] -= 1;
            node_reference[node18->input(0)] -= 1;
            if (node18->input_size() == 2)
            {
                node_reference[node18->input(1)] -= 1;
            }
            node_reference[node19->input(0)] -= 1;
            node_reference[node20->input(0)] -= 1;

            // intermediate blobs no longer exist
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));
            blob_names.erase(node7->output(0));
            blob_names.erase(node8->output(0));
            blob_names.erase(node9->output(0));
            blob_names.erase(node10->output(0));
            blob_names.erase(node11->output(0));
            blob_names.erase(node12->output(0));
            blob_names.erase(node13->output(0));
            blob_names.erase(node14->output(0));
            blob_names.erase(node15->output(0));
            blob_names.erase(node16->output(0));
            blob_names.erase(node17->output(0));
            blob_names.erase(node18->output(0));
            blob_names.erase(node19->output(0));

            // projection weight/bias blob names, captured before rewiring
            std::string qw = node->input(1);
            std::string qb = node2->input(1);
            std::string kw = node3->input(1);
            std::string kb = node4->input(1);
            std::string vw = node5->input(1);
            std::string vb = node6->input(1);
            std::string ow = node19->input(1);
            std::string ob = node20->input(1);

            // rewrite the final Add into the fused op: inputs are q/k/v blobs
            // followed by the four weight/bias pairs
            node20->set_op_type("MultiHeadAttention");
            node20->clear_input();
            node20->add_input(node->input(0));
            node20->add_input(node3->input(0));
            node20->add_input(node5->input(0));
            // q
            node20->add_input(qw);
            node20->add_input(qb);
            // k
            node20->add_input(kw);
            node20->add_input(kb);
            // v
            node20->add_input(vw);
            node20->add_input(vb);
            // out linear
            node20->add_input(ow);
            node20->add_input(ob);

            onnx::AttributeProto* attr_embed_dim = node20->add_attribute();
            attr_embed_dim->set_name("embed_dim");
            attr_embed_dim->set_i(embed_dim);

            onnx::AttributeProto* attr_num_heads = node20->add_attribute();
            attr_num_heads->set_name("num_heads");
            attr_num_heads->set_i(num_heads);

            reduced_node_count += 19;
            i += 19;
        }
    }

    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // MultiHeadAttention <= MatMul(qkv) - Add - Split
        //                       - Mul
        //                       - Reshape - Transpose
        //                       - Reshape - Reshape - Transpose - Transpose
        //                       - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
        if (node->op_type() == "MatMul")
        {
            if (i + 16 >= node_count)
                continue;

            if (node_reference[node->output(0)] != 1)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
            onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
            onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
            onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
            onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
            onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
            onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
            onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
            onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
            onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
            onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);

            // exact op-type sequence check
            if (node2->op_type() != "Add" || node3->op_type() != "Split" || node4->op_type() != "Mul" || node5->op_type() != "Reshape" || node6->op_type() != "Transpose" || node7->op_type() != "Reshape" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Transpose" || node11->op_type() != "MatMul" || node12->op_type() != "Softmax" || node13->op_type() != "MatMul" || node14->op_type() != "Transpose" || node15->op_type() != "Reshape" || node16->op_type() != "MatMul" || node17->op_type() != "Add")
                continue;

            // single-consumer check, including all three Split outputs
            if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node3->output(1)] != 1 || node_reference[node3->output(2)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1)
                continue;

            // the nodes must be wired together in the expected dataflow shape
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node3->output(1) || node8->input(0) != node3->output(2) || node9->input(0) != node8->output(0) || node10->input(0) != node7->output(0) || node11->input(0) != node6->output(0) || node11->input(1) != node10->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node12->output(0) || node13->input(1) != node9->output(0) || node14->input(0) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0))
                continue;

            // packed qkv bias is 3x the output bias length
            std::vector<float> qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]);
            std::vector<float> o_B = get_node_attr_from_input_af(weights[node17->input(1)]);

            if (qkv_B.size() != o_B.size() * 3)
                continue;

            int embed_dim = o_B.size();

            // 1 0 2
            std::vector<int> perm6 = get_node_attr_ai(*node6, "perm");
            std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
            if (perm6.size() != 3 || perm9.size() != 3)
                continue;

            if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2)
                continue;

            // 1 2 0
            std::vector<int> perm10 = get_node_attr_ai(*node10, "perm");
            if (perm10.size() != 3)
                continue;

            if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0)
                continue;

            // 1 0 2
            std::vector<int> perm14 = get_node_attr_ai(*node14, "perm");
            if (perm14.size() != 3)
                continue;

            if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2)
                continue;

            int softmax_axis = get_node_attr_i(*node12, "axis");
            if (softmax_axis != 2)
                continue;

            // 1/-1, seqlen * num_heads, embed_dim / num_heads
            std::vector<int> shape5;
            std::vector<int> shape7;
            std::vector<int> shape8;
            if (node5->input_size() == 1)
            {
                shape5 = get_node_attr_ai(*node5, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node5->input(1)) == weights.end())
                    continue;

                shape5 = get_node_attr_from_input_ai(weights[node5->input(1)]);
            }
            if (node7->input_size() == 1)
            {
                shape7 = get_node_attr_ai(*node7, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node7->input(1)) == weights.end())
                    continue;

                shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]);
            }
            if (node8->input_size() == 1)
            {
                shape8 = get_node_attr_ai(*node8, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node8->input(1)) == weights.end())
                    continue;

                shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
            }

            if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3)
                continue;

            if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] || shape5[2] != shape8[2])
                continue;

            // shape5[2] is the per-head dimension
            int num_heads = embed_dim / shape5[2];

            // 1, seqlen, embed_dim
            std::vector<int> shape15;
            if (node15->input_size() == 1)
            {
                shape15 = get_node_attr_ai(*node15, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node15->input(1)) == weights.end())
                    continue;

                shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]);
            }

            if (shape15.size() != 3)
                continue;

            if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1])
                continue;

            // reduce: retype all but the final Add, which becomes the fused node
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");
            node7->set_op_type("noop_reducedncnn");
            node8->set_op_type("noop_reducedncnn");
            node9->set_op_type("noop_reducedncnn");
            node10->set_op_type("noop_reducedncnn");
            node11->set_op_type("noop_reducedncnn");
            node12->set_op_type("noop_reducedncnn");
            node13->set_op_type("noop_reducedncnn");
            node14->set_op_type("noop_reducedncnn");
            node15->set_op_type("noop_reducedncnn");
            node16->set_op_type("noop_reducedncnn");

            // drop one reference for every consumed intermediate/constant input
            node_reference[node2->input(0)] -= 1;
            node_reference[node3->input(0)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node4->input(1)] -= 1;
            node_reference[node5->input(0)] -= 1;
            if (node5->input_size() == 2)
            {
                node_reference[node5->input(1)] -= 1;
            }
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            if (node7->input_size() == 2)
            {
                node_reference[node7->input(1)] -= 1;
            }
            node_reference[node8->input(0)] -= 1;
            if (node8->input_size() == 2)
            {
                node_reference[node8->input(1)] -= 1;
            }
            node_reference[node9->input(0)] -= 1;
            node_reference[node10->input(0)] -= 1;
            node_reference[node11->input(0)] -= 1;
            node_reference[node11->input(1)] -= 1;
            node_reference[node12->input(0)] -= 1;
            node_reference[node13->input(0)] -= 1;
            node_reference[node13->input(1)] -= 1;
            node_reference[node14->input(0)] -= 1;
            node_reference[node15->input(0)] -= 1;
            if (node15->input_size() == 2)
            {
                node_reference[node15->input(1)] -= 1;
            }
            node_reference[node16->input(0)] -= 1;
            node_reference[node17->input(0)] -= 1;

            // intermediate blobs no longer exist
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node3->output(1));
            blob_names.erase(node3->output(2));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));
            blob_names.erase(node7->output(0));
            blob_names.erase(node8->output(0));
            blob_names.erase(node9->output(0));
            blob_names.erase(node10->output(0));
            blob_names.erase(node11->output(0));
            blob_names.erase(node12->output(0));
            blob_names.erase(node13->output(0));
            blob_names.erase(node14->output(0));
            blob_names.erase(node15->output(0));
            blob_names.erase(node16->output(0));

            // packed projection weight/bias blob names, captured before rewiring
            std::string qkvw = node->input(1);
            std::string qkvb = node2->input(1);
            std::string ow = node16->input(1);
            std::string ob = node17->input(1);

            // rewrite the final Add into the fused op
            node17->set_op_type("MultiHeadAttention");
            node17->clear_input();
            node17->add_input(node->input(0));
            // qkv
            node17->add_input(qkvw);
            node17->add_input(qkvb);
            // out linear
            node17->add_input(ow);
            node17->add_input(ob);

            onnx::AttributeProto* attr_embed_dim = node17->add_attribute();
            attr_embed_dim->set_name("embed_dim");
            attr_embed_dim->set_i(embed_dim);

            onnx::AttributeProto* attr_num_heads = node17->add_attribute();
            attr_num_heads->set_name("num_heads");
            attr_num_heads->set_i(num_heads);

            reduced_node_count += 16;
            i += 16;
        }
    }
}
2793
fuse_binaryop_with_scalar(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2794 static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2795 {
2796 int node_count = mutable_graph->node_size();
2797 for (int i = 0; i < node_count; i++)
2798 {
2799 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2800
2801 // Add/Sub/Mul/Div/Min/Max/Pow
2802 if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || node->op_type() == "Pow")
2803 {
2804 if (weights.find(node->input(1)) == weights.end())
2805 continue;
2806
2807 const onnx::TensorProto& scalar_b = weights[node->input(1)];
2808 if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1)
2809 continue;
2810
2811 float b = get_node_attr_from_input_f(scalar_b);
2812
2813 node_reference[node->input(1)] -= 1;
2814
2815 std::string input = node->input(0);
2816
2817 node->clear_input();
2818 node->add_input(input);
2819
2820 onnx::AttributeProto* attr_with_scalar = node->add_attribute();
2821 attr_with_scalar->set_name("with_scalar");
2822 attr_with_scalar->set_i(1);
2823
2824 onnx::AttributeProto* attr_b = node->add_attribute();
2825 attr_b->set_name("b");
2826 attr_b->set_f(b);
2827 }
2828 }
2829 }
2830
main(int argc,char ** argv)2831 int main(int argc, char** argv)
2832 {
2833 const char* onnxpb = argv[1];
2834 const char* ncnn_prototxt = argc >= 4 ? argv[2] : "ncnn.param";
2835 const char* ncnn_modelbin = argc >= 4 ? argv[3] : "ncnn.bin";
2836
2837 onnx::ModelProto model;
2838
2839 // load
2840 bool s1 = read_proto_from_binary(onnxpb, &model);
2841 if (!s1)
2842 {
2843 fprintf(stderr, "read_proto_from_binary failed\n");
2844 return -1;
2845 }
2846
2847 FILE* pp = fopen(ncnn_prototxt, "wb");
2848 FILE* bp = fopen(ncnn_modelbin, "wb");
2849
2850 // magic
2851 fprintf(pp, "7767517\n");
2852
2853 const onnx::GraphProto& graph = model.graph();
2854 onnx::GraphProto* mutable_graph = model.mutable_graph();
2855
2856 int node_count = graph.node_size();
2857
2858 // node reference
2859 std::map<std::string, int> node_reference;
2860
2861 // weight node and weight reshape node
2862 std::map<std::string, onnx::TensorProto> weights;
2863
2864 for (int j = 0; j < graph.initializer_size(); j++)
2865 {
2866 const onnx::TensorProto& initializer = graph.initializer(j);
2867
2868 // fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), initializer.data_type());
2869
2870 weights[initializer.name()] = initializer;
2871 }
2872
2873 // topological sort
2874 {
2875 // name -> producer node index
2876 std::set<std::string> producers;
2877 for (int j = 0; j < graph.input_size(); j++)
2878 {
2879 const std::string& input_name = graph.input(j).name();
2880 producers.insert(input_name);
2881 }
2882
2883 for (int i = 0; i < node_count;)
2884 {
2885 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2886
2887 bool swapnode = false;
2888 std::string missing_input_name;
2889 for (int j = 0; j < (int)node->input_size(); j++)
2890 {
2891 const std::string& input_name = node->input(j);
2892 if (input_name.empty())
2893 continue;
2894
2895 if (producers.find(input_name) == producers.end() && weights.find(input_name) == weights.end())
2896 {
2897 swapnode = true;
2898 missing_input_name = input_name;
2899 break;
2900 }
2901 }
2902
2903 if (!swapnode)
2904 {
2905 for (int j = 0; j < (int)node->output_size(); j++)
2906 {
2907 const std::string& output_name = node->output(j);
2908 if (output_name.empty())
2909 continue;
2910
2911 producers.insert(output_name);
2912 }
2913
2914 i++;
2915 continue;
2916 }
2917
2918 // find node that produce missing_input_name
2919 int q = i + 1;
2920 for (; q < node_count; q++)
2921 {
2922 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2923 bool found = false;
2924 for (int j = 0; j < (int)nodeq->output_size(); j++)
2925 {
2926 const std::string& output_name = nodeq->output(j);
2927 if (output_name == missing_input_name)
2928 {
2929 found = true;
2930 break;
2931 }
2932 }
2933
2934 if (found)
2935 break;
2936 }
2937
2938 if (q == node_count)
2939 {
2940 fprintf(stderr, "cannot find node produces %s but node %d requires it\n", missing_input_name.c_str(), i);
2941 return -1;
2942 }
2943
2944 // fprintf(stderr, "swap %d %d\n", i, q);
2945 // swap this node with q
2946 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2947 onnx::NodeProto tmp = *node;
2948 *node = *nodeq;
2949 *nodeq = tmp;
2950 }
2951 }
2952
2953 // global definition line
2954 // [layer count] [blob count]
2955 std::set<std::string> blob_names;
2956 for (int i = 0; i < node_count; i++)
2957 {
2958 const onnx::NodeProto& node = graph.node(i);
2959
2960 const std::string& op = node.op_type();
2961
2962 std::string name = node.name();
2963 if (name.empty())
2964 {
2965 name = node.output(0);
2966 }
2967
2968 if (op == "Constant")
2969 {
2970 onnx::TensorProto tensor = get_node_attr_tensor(node, "value");
2971 weights[node.output(0)] = tensor;
2972 }
2973
2974 for (int j = 0; j < (int)node.input_size(); j++)
2975 {
2976 const std::string& input_name = node.input(j);
2977
2978 blob_names.insert(input_name);
2979
2980 if (node_reference.find(input_name) == node_reference.end())
2981 {
2982 node_reference[input_name] = 1;
2983 }
2984 else
2985 {
2986 node_reference[input_name] = node_reference[input_name] + 1;
2987 }
2988 }
2989
2990 if (op == "Dropout")
2991 {
2992 const std::string& output_name = node.output(0);
2993 blob_names.insert(output_name);
2994 node_reference[output_name] = 0;
2995 continue;
2996 }
2997
2998 for (int j = 0; j < (int)node.output_size(); j++)
2999 {
3000 const std::string& output_name = node.output(j);
3001
3002 blob_names.insert(output_name);
3003
3004 node_reference[output_name] = 0;
3005 }
3006 }
3007
3008 // include Input node
3009 int input_node_count = 0;
3010 for (int j = 0; j < graph.input_size(); j++)
3011 {
3012 const std::string& input_name = graph.input(j).name();
3013
3014 // check weight
3015 if (weights.find(input_name) != weights.end())
3016 continue;
3017
3018 blob_names.insert(input_name);
3019
3020 input_node_count++;
3021 }
3022
3023 // for (auto a: node_reference)
3024 // {
3025 // fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second);
3026 // }
3027
3028 // op chain fusion
3029 int reduced_node_count = 0;
3030 fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3031 fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3032 fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3033 fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3034 fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3035 fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3036 fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3037 fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3038 fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3039 fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3040 fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3041 fuse_layernorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3042 fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3043 fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3044 fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3045 fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3046 fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3047 fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3048 fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3049
3050 // reduce common const weight node_reference
3051 for (int i = 0; i < node_count; i++)
3052 {
3053 const onnx::NodeProto& node = graph.node(i);
3054
3055 const std::string& op = node.op_type();
3056
3057 if (op == "BatchNormalization")
3058 {
3059 node_reference[node.input(1)] -= 1;
3060 node_reference[node.input(2)] -= 1;
3061 node_reference[node.input(3)] -= 1;
3062 node_reference[node.input(4)] -= 1;
3063 }
3064 else if (op == "BiasGelu")
3065 {
3066 node_reference[node.input(1)] -= 1;
3067 }
3068 else if (op == "Clip")
3069 {
3070 if (node.input_size() == 3)
3071 {
3072 node_reference[node.input(1)] -= 1;
3073 node_reference[node.input(2)] -= 1;
3074 }
3075 }
3076 else if (op == "Conv")
3077 {
3078 node_reference[node.input(1)] -= 1;
3079 if (node.input_size() == 3)
3080 {
3081 node_reference[node.input(2)] -= 1;
3082 }
3083 }
3084 else if (op == "ConvTranspose")
3085 {
3086 node_reference[node.input(1)] -= 1;
3087 if (node.input_size() == 3)
3088 {
3089 node_reference[node.input(2)] -= 1;
3090 }
3091 }
3092 else if (op == "EmbedLayerNormalization")
3093 {
3094 node_reference[node.input(1)] -= 1;
3095 node_reference[node.input(2)] -= 1;
3096 node_reference[node.input(3)] -= 1;
3097 node_reference[node.input(4)] -= 1;
3098 node_reference[node.input(5)] -= 1;
3099 node_reference[node.input(6)] -= 1;
3100 }
3101 else if (op == "Gemm")
3102 {
3103 float alpha = get_node_attr_f(node, "alpha", 1.f);
3104 float beta = get_node_attr_f(node, "beta", 1.f);
3105 int transA = get_node_attr_i(node, "transA", 0);
3106 int transB = get_node_attr_i(node, "transB", 0);
3107
3108 if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
3109 {
3110 // InnerProduct-like A * B + C
3111 node_reference[node.input(1)] -= 1;
3112 node_reference[node.input(2)] -= 1;
3113 }
3114 }
3115 else if (op == "GroupNorm")
3116 {
3117 int affine = get_node_attr_i(node, "affine", 1);
3118 if (affine)
3119 {
3120 node_reference[node.input(1)] -= 1;
3121 node_reference[node.input(2)] -= 1;
3122 }
3123 }
3124 else if (op == "GRU")
3125 {
3126 for (int j = 1; j < node.input_size(); j++)
3127 {
3128 node_reference[node.input(j)] -= 1;
3129 }
3130 }
3131 else if (op == "InstanceNormalization")
3132 {
3133 node_reference[node.input(1)] -= 1;
3134 node_reference[node.input(2)] -= 1;
3135 }
3136 else if (op == "LayerNorm")
3137 {
3138 int affine = get_node_attr_i(node, "affine", 1);
3139 if (affine)
3140 {
3141 node_reference[node.input(1)] -= 1;
3142 node_reference[node.input(2)] -= 1;
3143 }
3144 }
3145 else if (op == "LSTM")
3146 {
3147 for (int j = 1; j < node.input_size(); j++)
3148 {
3149 node_reference[node.input(j)] -= 1;
3150 }
3151 }
3152 else if (op == "MatMul")
3153 {
3154 if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
3155 {
3156 // InnerProduct
3157 node_reference[node.input(1)] -= 1;
3158 }
3159 }
3160 else if (op == "MultiHeadAttention")
3161 {
3162 if (node.input_size() == 5)
3163 {
3164 node_reference[node.input(1)] -= 1;
3165 node_reference[node.input(2)] -= 1;
3166 node_reference[node.input(3)] -= 1;
3167 node_reference[node.input(4)] -= 1;
3168 }
3169 else
3170 {
3171 node_reference[node.input(3)] -= 1;
3172 node_reference[node.input(4)] -= 1;
3173 node_reference[node.input(5)] -= 1;
3174 node_reference[node.input(6)] -= 1;
3175 node_reference[node.input(7)] -= 1;
3176 node_reference[node.input(8)] -= 1;
3177 node_reference[node.input(9)] -= 1;
3178 node_reference[node.input(10)] -= 1;
3179 }
3180 }
3181 else if (op == "Pad")
3182 {
3183 if (node.input_size() >= 2)
3184 {
3185 node_reference[node.input(1)] -= 1;
3186 }
3187 }
3188 else if (op == "PRelu")
3189 {
3190 node_reference[node.input(1)] -= 1;
3191 }
3192 else if (op == "Reshape")
3193 {
3194 if (node.input_size() >= 2)
3195 {
3196 node_reference[node.input(1)] -= 1;
3197 }
3198 }
3199 else if (op == "Resize")
3200 {
3201 if (node.input_size() == 2)
3202 {
3203 // opset 10
3204 node_reference[node.input(1)] -= 1;
3205 }
3206 else
3207 {
3208 // opset 11+
3209 node_reference[node.input(1)] -= 1;
3210 node_reference[node.input(2)] -= 1;
3211 if (node.input_size() >= 4)
3212 {
3213 node_reference[node.input(3)] -= 1;
3214 }
3215 }
3216 }
3217 else if (op == "RNN")
3218 {
3219 for (int j = 1; j < node.input_size(); j++)
3220 {
3221 node_reference[node.input(j)] -= 1;
3222 }
3223 }
3224 else if (op == "SkipLayerNormalization")
3225 {
3226 node_reference[node.input(2)] -= 1;
3227 node_reference[node.input(3)] -= 1;
3228 node_reference[node.input(4)] -= 1;
3229 }
3230 else if (op == "Slice")
3231 {
3232 if (node.input_size() >= 2)
3233 {
3234 node_reference[node.input(1)] -= 1;
3235 node_reference[node.input(2)] -= 1;
3236 if (node.input_size() >= 4)
3237 node_reference[node.input(3)] -= 1;
3238 if (node.input_size() >= 5)
3239 node_reference[node.input(4)] -= 1;
3240 }
3241 }
3242 else if (op == "Upsample")
3243 {
3244 if (node.input_size() >= 2)
3245 {
3246 node_reference[node.input(1)] -= 1;
3247 }
3248 }
3249 else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
3250 {
3251 if (node.input_size() >= 2)
3252 {
3253 node_reference[node.input(1)] -= 1;
3254 }
3255 }
3256 }
3257
3258 // for (auto a: node_reference)
3259 // {
3260 // fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
3261 // }
3262
3263 // count all weight node with zero reference
3264 int zero_reference_weight_node_count = 0;
3265 for (std::map<std::string, onnx::TensorProto>::iterator it = weights.begin(); it != weights.end(); it++)
3266 {
3267 const std::string& input_name = it->first;
3268
3269 int refcount = node_reference[input_name];
3270 if (refcount == 0)
3271 zero_reference_weight_node_count++;
3272 }
3273
3274 // we always treat constant node as weight or binaryop_weights
3275 // do not count it twice for layer_count
3276 int constant_node_count_moved_to_weight = 0;
3277 for (int i = 0; i < node_count; i++)
3278 {
3279 const onnx::NodeProto& node = graph.node(i);
3280
3281 const std::string& op = node.op_type();
3282
3283 if (op == "Constant")
3284 {
3285 constant_node_count_moved_to_weight++;
3286 }
3287 }
3288
3289 // some op may have anonymous input
3290 // LSTM sequence_lens
3291 blob_names.erase("");
3292 node_reference.erase("");
3293
3294 // remove node_reference entry with reference equals to one
3295 int split_layer_count = 0;
3296 int splitncnn_blob_count = 0;
3297 // split node reference
3298 std::map<std::string, int> split_node_reference;
3299 for (std::map<std::string, int>::iterator it = node_reference.begin(); it != node_reference.end(); it++)
3300 {
3301 if (it->second > 1)
3302 {
3303 split_layer_count++;
3304 splitncnn_blob_count += it->second;
3305
3306 split_node_reference[it->first] = it->second;
3307 }
3308 }
3309
3310 fprintf(pp, "%zu %zu\n", node_count - constant_node_count_moved_to_weight + weights.size() - zero_reference_weight_node_count - reduced_node_count + input_node_count + split_layer_count, blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count);
3311
3312 int internal_split = 0;
3313
3314 // place Input at the beginning
3315 for (int j = 0; j < graph.input_size(); j++)
3316 {
3317 const std::string& input_name = graph.input(j).name();
3318
3319 // check weight
3320 if (weights.find(input_name) != weights.end())
3321 continue;
3322
3323 fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());
3324
3325 int refcount = node_reference[input_name];
3326 if (refcount <= 1)
3327 {
3328 continue;
3329 }
3330
3331 char splitname[256];
3332 sprintf(splitname, "splitncnn_input%d", j);
3333 fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
3334 fprintf(pp, " %s", input_name.c_str());
3335
3336 for (int k = 0; k < refcount; k++)
3337 {
3338 fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
3339 }
3340 fprintf(pp, "\n");
3341 }
3342
3343 // place MemoryData next
3344 for (std::map<std::string, onnx::TensorProto>::iterator weight_it = weights.begin(); weight_it != weights.end(); weight_it++)
3345 {
3346 const std::string& input_name = weight_it->first;
3347
3348 int refcount = node_reference[input_name];
3349 if (refcount == 0)
3350 {
3351 continue;
3352 }
3353
3354 fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());
3355
3356 const onnx::TensorProto& M = weights[input_name];
3357
3358 if (M.dims_size() == 0)
3359 {
3360 fprintf(pp, " 0=%d", get_tensor_proto_data_size(M));
3361 }
3362 else if (M.dims_size() == 1)
3363 {
3364 fprintf(pp, " 0=%d", (int)M.dims(0));
3365 }
3366 else if (M.dims_size() == 2)
3367 {
3368 fprintf(pp, " 0=%d", (int)M.dims(1));
3369 if (M.dims(0) != 1)
3370 {
3371 fprintf(pp, " 1=%d", (int)M.dims(0));
3372 }
3373 }
3374 else if (M.dims_size() == 3)
3375 {
3376 fprintf(pp, " 0=%d", (int)M.dims(2));
3377 fprintf(pp, " 1=%d", (int)M.dims(1));
3378 if (M.dims(0) != 1)
3379 {
3380 fprintf(pp, " 2=%d", (int)M.dims(0));
3381 }
3382 }
3383 else if (M.dims_size() == 4)
3384 {
3385 fprintf(pp, " 0=%d", (int)M.dims(3));
3386 fprintf(pp, " 1=%d", (int)M.dims(2));
3387 fprintf(pp, " 2=%d", (int)M.dims(1));
3388 }
3389
3390 fprintf(pp, "\n");
3391
3392 fwrite_tensor_proto_data(M, bp);
3393
3394 if (refcount <= 1)
3395 {
3396 continue;
3397 }
3398
3399 char splitname[256];
3400 sprintf(splitname, "splitncnn_%d", internal_split);
3401 fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
3402
3403 fprintf(pp, " %s", input_name.c_str());
3404
3405 for (int k = 0; k < refcount; k++)
3406 {
3407 fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
3408 }
3409 fprintf(pp, "\n");
3410
3411 internal_split++;
3412 }
3413
3414 for (int i = 0; i < node_count; i++)
3415 {
3416 const onnx::NodeProto& node = graph.node(i);
3417
3418 const std::string& op = node.op_type();
3419
3420 // fprintf(stderr, "op = %s\n", op.c_str());
3421
3422 if (op == "noop_reducedncnn")
3423 {
3424 continue;
3425 }
3426
3427 std::string name = node.name();
3428 if (name.empty())
3429 {
3430 name = node.output(0);
3431 }
3432
3433 int input_size = node.input_size();
3434 int output_size = node.output_size();
3435
3436 for (int j = 0; j < (int)node.input_size(); j++)
3437 {
3438 const std::string& input_name = node.input(j);
3439
3440 // check weight
3441 if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
3442 {
3443 input_size--;
3444 }
3445
3446 if (input_name.empty())
3447 {
3448 input_size--;
3449 }
3450
3451 // fprintf(stderr, " input = %s\n", input_name.c_str());
3452 }
3453 /*
3454 for (int j=0; j<(int)node.output_size(); j++)
3455 {
3456 const std::string& output_name = node.output(j);
3457 fprintf(stderr, " output = %s\n", output_name.c_str());
3458 }
3459 */
3460
3461 if (op == "Abs")
3462 {
3463 fprintf(pp, "%-16s", "UnaryOp");
3464 }
3465 else if (op == "Acos")
3466 {
3467 fprintf(pp, "%-16s", "UnaryOp");
3468 }
3469 else if (op == "Add")
3470 {
3471 fprintf(pp, "%-16s", "BinaryOp");
3472 }
3473 else if (op == "Asin")
3474 {
3475 fprintf(pp, "%-16s", "UnaryOp");
3476 }
3477 else if (op == "Atan")
3478 {
3479 fprintf(pp, "%-16s", "UnaryOp");
3480 }
3481 else if (op == "AveragePool" || op == "MaxPool")
3482 {
3483 fprintf(pp, "%-16s", "Pooling");
3484 }
3485 else if (op == "BatchNormalization")
3486 {
3487 fprintf(pp, "%-16s", "BatchNorm");
3488 }
3489 else if (op == "BiasGelu")
3490 {
3491 fprintf(pp, "%-16s", "BiasGelu");
3492 }
3493 else if (op == "Ceil")
3494 {
3495 fprintf(pp, "%-16s", "UnaryOp");
3496 }
3497 else if (op == "Clip")
3498 {
3499 fprintf(pp, "%-16s", "Clip");
3500 }
3501 else if (op == "Concat")
3502 {
3503 fprintf(pp, "%-16s", "Concat");
3504 }
3505 else if (op == "Constant")
3506 {
3507 continue;
3508 }
3509 else if (op == "Conv")
3510 {
3511 int group = get_node_attr_i(node, "group", 1);
3512 if (group > 1)
3513 {
3514 fprintf(pp, "%-16s", "ConvolutionDepthWise");
3515 }
3516 else
3517 {
3518 fprintf(pp, "%-16s", "Convolution");
3519 }
3520 }
3521 else if (op == "ConvTranspose")
3522 {
3523 int group = get_node_attr_i(node, "group", 1);
3524 if (group > 1)
3525 {
3526 fprintf(pp, "%-16s", "DeconvolutionDepthWise");
3527 }
3528 else
3529 {
3530 fprintf(pp, "%-16s", "Deconvolution");
3531 }
3532 }
3533 else if (op == "Cos")
3534 {
3535 fprintf(pp, "%-16s", "UnaryOp");
3536 }
3537 else if (op == "DepthToSpace")
3538 {
3539 fprintf(pp, "%-16s", "PixelShuffle");
3540 }
3541 else if (op == "Div")
3542 {
3543 fprintf(pp, "%-16s", "BinaryOp");
3544 }
3545 else if (op == "Dropout")
3546 {
3547 fprintf(pp, "%-16s", "Dropout");
3548 output_size = 1;
3549 }
3550 else if (op == "Elu")
3551 {
3552 fprintf(pp, "%-16s", "ELU");
3553 }
3554 else if (op == "EmbedLayerNormalization")
3555 {
3556 fprintf(pp, "%-16s", "EmbedLayerNormalization");
3557 }
3558 else if (op == "Exp")
3559 {
3560 fprintf(pp, "%-16s", "UnaryOp");
3561 }
3562 else if (op == "Flatten")
3563 {
3564 fprintf(pp, "%-16s", "Flatten");
3565 }
3566 else if (op == "Floor")
3567 {
3568 fprintf(pp, "%-16s", "UnaryOp");
3569 }
3570 else if (op == "Gemm")
3571 {
3572 float alpha = get_node_attr_f(node, "alpha", 1.f);
3573 float beta = get_node_attr_f(node, "beta", 1.f);
3574 int transA = get_node_attr_i(node, "transA", 0);
3575 int transB = get_node_attr_i(node, "transB", 0);
3576
3577 if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
3578 {
3579 // InnerProduct-like A * B + C
3580 fprintf(pp, "%-16s", "InnerProduct");
3581 }
3582 else
3583 {
3584 fprintf(pp, "%-16s", "Gemm");
3585 }
3586 }
3587 else if (op == "GlobalAveragePool")
3588 {
3589 fprintf(pp, "%-16s", "Pooling");
3590 }
3591 else if (op == "GlobalMaxPool")
3592 {
3593 fprintf(pp, "%-16s", "Pooling");
3594 }
3595 else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
3596 {
3597 fprintf(pp, "%-16s", "Pooling");
3598 }
3599 else if (op == "GroupNorm")
3600 {
3601 fprintf(pp, "%-16s", "GroupNorm");
3602 }
3603 else if (op == "GRU")
3604 {
3605 fprintf(pp, "%-16s", "GRU");
3606 }
3607 else if (op == "HardSigmoid")
3608 {
3609 fprintf(pp, "%-16s", "HardSigmoid");
3610 }
3611 else if (op == "HardSwish")
3612 {
3613 fprintf(pp, "%-16s", "HardSwish");
3614 }
3615 else if (op == "ImageScaler")
3616 {
3617 fprintf(pp, "%-16s", "Scale");
3618 }
3619 else if (op == "InstanceNormalization")
3620 {
3621 fprintf(pp, "%-16s", "InstanceNorm");
3622 }
3623 else if (op == "LayerNorm")
3624 {
3625 fprintf(pp, "%-16s", "LayerNorm");
3626 }
3627 else if (op == "LeakyRelu")
3628 {
3629 fprintf(pp, "%-16s", "ReLU");
3630 }
3631 else if (op == "Log")
3632 {
3633 fprintf(pp, "%-16s", "UnaryOp");
3634 }
3635 else if (op == "LRN")
3636 {
3637 fprintf(pp, "%-16s", "LRN");
3638 }
3639 else if (op == "LSTM")
3640 {
3641 fprintf(pp, "%-16s", "LSTM");
3642 }
3643 else if (op == "MatMul")
3644 {
3645 if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
3646 {
3647 fprintf(pp, "%-16s", "InnerProduct");
3648 }
3649 else
3650 {
3651 fprintf(pp, "%-16s", "Gemm");
3652 }
3653 }
3654 else if (op == "Max")
3655 {
3656 fprintf(pp, "%-16s", "BinaryOp");
3657 }
3658 else if (op == "Min")
3659 {
3660 fprintf(pp, "%-16s", "BinaryOp");
3661 }
3662 else if (op == "Mul")
3663 {
3664 fprintf(pp, "%-16s", "BinaryOp");
3665 }
3666 else if (op == "MultiHeadAttention")
3667 {
3668 fprintf(pp, "%-16s", "MultiHeadAttention");
3669 }
3670 else if (op == "Neg")
3671 {
3672 fprintf(pp, "%-16s", "UnaryOp");
3673 }
3674 else if (op == "Normalize")
3675 {
3676 fprintf(pp, "%-16s", "Normalize");
3677 }
3678 else if (op == "Pad")
3679 {
3680 fprintf(pp, "%-16s", "Padding");
3681 }
3682 else if (op == "PixelShuffle")
3683 {
3684 fprintf(pp, "%-16s", "PixelShuffle");
3685 }
3686 else if (op == "Pow")
3687 {
3688 fprintf(pp, "%-16s", "BinaryOp");
3689 }
3690 else if (op == "PRelu")
3691 {
3692 fprintf(pp, "%-16s", "PReLU");
3693 }
3694 else if (op == "Reciprocal")
3695 {
3696 fprintf(pp, "%-16s", "UnaryOp");
3697 }
3698 else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
3699 {
3700 fprintf(pp, "%-16s", "Reduction");
3701 }
3702 else if (op == "Relu")
3703 {
3704 fprintf(pp, "%-16s", "ReLU");
3705 }
3706 else if (op == "Reorg")
3707 {
3708 fprintf(pp, "%-16s", "Reorg");
3709 }
3710 else if (op == "Reshape")
3711 {
3712 fprintf(pp, "%-16s", "Reshape");
3713 }
3714 else if (op == "RNN")
3715 {
3716 fprintf(pp, "%-16s", "RNN");
3717 }
3718 else if (op == "ShuffleChannel")
3719 {
3720 fprintf(pp, "%-16s", "ShuffleChannel");
3721 }
3722 else if (op == "Sigmoid")
3723 {
3724 fprintf(pp, "%-16s", "Sigmoid");
3725 }
3726 else if (op == "Sin")
3727 {
3728 fprintf(pp, "%-16s", "UnaryOp");
3729 }
3730 else if (op == "SkipLayerNormalization")
3731 {
3732 fprintf(pp, "%-16s", "SkipLayerNormalization");
3733 }
3734 else if (op == "Slice")
3735 {
3736 fprintf(pp, "%-16s", "Crop");
3737 }
3738 else if (op == "Softmax")
3739 {
3740 fprintf(pp, "%-16s", "Softmax");
3741 }
3742 else if (op == "Softplus")
3743 {
3744 fprintf(pp, "%-16s", "Softplus");
3745 }
3746 else if (op == "Split")
3747 {
3748 fprintf(pp, "%-16s", "Slice");
3749 }
3750 else if (op == "Sqrt")
3751 {
3752 fprintf(pp, "%-16s", "UnaryOp");
3753 }
3754 else if (op == "Squeeze")
3755 {
3756 fprintf(pp, "%-16s", "Squeeze");
3757 }
3758 else if (op == "Sub")
3759 {
3760 fprintf(pp, "%-16s", "BinaryOp");
3761 }
3762 else if (op == "Sum")
3763 {
3764 fprintf(pp, "%-16s", "Eltwise");
3765 }
3766 else if (op == "Swish")
3767 {
3768 fprintf(pp, "%-16s", "Swish");
3769 }
3770 else if (op == "Tan")
3771 {
3772 fprintf(pp, "%-16s", "UnaryOp");
3773 }
3774 else if (op == "Tanh")
3775 {
3776 fprintf(pp, "%-16s", "UnaryOp");
3777 }
3778 else if (op == "Transpose")
3779 {
3780 fprintf(pp, "%-16s", "Permute");
3781 }
3782 else if (op == "Upsample" || op == "Resize")
3783 {
3784 fprintf(pp, "%-16s", "Interp");
3785 }
3786 else if (op == "Unsqueeze")
3787 {
3788 fprintf(pp, "%-16s", "ExpandDims");
3789 }
3790 else
3791 {
3792 // TODO
3793 fprintf(stderr, "%s not supported yet!\n", op.c_str());
3794 fprintf(pp, "%-16s", op.c_str());
3795 }
3796
3797 fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);
3798
3799 for (int j = 0; j < (int)node.input_size(); j++)
3800 {
3801 std::string input_name = node.input(j);
3802
3803 // check weight
3804 if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
3805 {
3806 continue;
3807 }
3808
3809 if (input_name.empty())
3810 {
3811 continue;
3812 }
3813
3814 if (split_node_reference.find(input_name) != split_node_reference.end())
3815 {
3816 int refidx = split_node_reference[input_name] - 1;
3817 split_node_reference[input_name] = refidx;
3818
3819 char splitsuffix[256];
3820 sprintf(splitsuffix, "_splitncnn_%d", refidx);
3821 input_name = input_name + splitsuffix;
3822 }
3823
3824 fprintf(pp, " %s", input_name.c_str());
3825 }
3826
3827 for (int j = 0; j < output_size; j++)
3828 {
3829 const std::string& output_name = node.output(j);
3830
3831 fprintf(pp, " %s", output_name.c_str());
3832 }
3833
3834 if (op == "Abs")
3835 {
3836 int op_type = 0;
3837 fprintf(pp, " 0=%d", op_type);
3838 }
3839 else if (op == "Acos")
3840 {
3841 int op_type = 13;
3842 fprintf(pp, " 0=%d", op_type);
3843 }
3844 else if (op == "Add")
3845 {
3846 int op_type = 0;
3847 fprintf(pp, " 0=%d", op_type);
3848
3849 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
3850 float b = get_node_attr_f(node, "b", 0.f);
3851 if (with_scalar)
3852 {
3853 fprintf(pp, " 1=%d", with_scalar);
3854 fprintf(pp, " 2=%e", b);
3855 }
3856 }
3857 else if (op == "Asin")
3858 {
3859 int op_type = 12;
3860 fprintf(pp, " 0=%d", op_type);
3861 }
3862 else if (op == "Atan")
3863 {
3864 int op_type = 14;
3865 fprintf(pp, " 0=%d", op_type);
3866 }
3867 else if (op == "AveragePool" || op == "MaxPool")
3868 {
3869 std::string auto_pad = get_node_attr_s(node, "auto_pad");
3870 int ceil_mode = get_node_attr_i(node, "ceil_mode", 0);
3871 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
3872 std::vector<int> strides = get_node_attr_ai(node, "strides");
3873 std::vector<int> pads = get_node_attr_ai(node, "pads");
3874
3875 int pool = op == "AveragePool" ? 1 : 0;
3876 int pad_mode = 1;
3877
3878 if (auto_pad == "SAME_UPPER")
3879 {
3880 pad_mode = 2;
3881 }
3882 else if (auto_pad == "SAME_LOWER")
3883 {
3884 pad_mode = 3;
3885 }
3886
3887 if (ceil_mode == 1)
3888 {
3889 pad_mode = 0;
3890 }
3891
3892 fprintf(pp, " 0=%d", pool);
3893
3894 if (kernel_shape.size() == 1)
3895 {
3896 fprintf(pp, " 1=%d", kernel_shape[0]);
3897 }
3898 else if (kernel_shape.size() == 2)
3899 {
3900 fprintf(pp, " 1=%d", kernel_shape[1]);
3901 fprintf(pp, " 11=%d", kernel_shape[0]);
3902 }
3903
3904 if (strides.size() == 1)
3905 {
3906 fprintf(pp, " 2=%d", strides[0]);
3907 }
3908 else if (strides.size() == 2)
3909 {
3910 fprintf(pp, " 2=%d", strides[1]);
3911 fprintf(pp, " 12=%d", strides[0]);
3912 }
3913
3914 if (pads.size() == 1)
3915 {
3916 fprintf(pp, " 3=%d", pads[0]);
3917 }
3918 else if (pads.size() == 2)
3919 {
3920 fprintf(pp, " 3=%d", pads[1]);
3921 fprintf(pp, " 13=%d", pads[0]);
3922 }
3923 else if (pads.size() == 4)
3924 {
3925 fprintf(pp, " 3=%d", pads[1]);
3926 fprintf(pp, " 13=%d", pads[0]);
3927 fprintf(pp, " 14=%d", pads[3]);
3928 fprintf(pp, " 15=%d", pads[2]);
3929 }
3930
3931 fprintf(pp, " 5=%d", pad_mode);
3932
3933 if (op == "AveragePool")
3934 {
3935 int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0);
3936 fprintf(pp, " 6=%d", avgpool_count_include_pad);
3937 }
3938 }
3939 else if (op == "BatchNormalization")
3940 {
3941 float epsilon = get_node_attr_f(node, "epsilon", 1e-5f);
3942
3943 const onnx::TensorProto& scale = weights[node.input(1)];
3944 const onnx::TensorProto& B = weights[node.input(2)];
3945 const onnx::TensorProto& mean = weights[node.input(3)];
3946 const onnx::TensorProto& var = weights[node.input(4)];
3947
3948 int channels = get_tensor_proto_data_size(scale);
3949
3950 fprintf(pp, " 0=%d", channels);
3951
3952 fwrite_tensor_proto_data(scale, bp);
3953 fwrite_tensor_proto_data(mean, bp);
3954 // apply epsilon to var
3955 {
3956 const float* v = var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data();
3957
3958 for (int j = 0; j < channels; j++)
3959 {
3960 float ve = v[j] + epsilon;
3961 fwrite(&ve, sizeof(float), 1, bp);
3962 }
3963 }
3964 fwrite_tensor_proto_data(B, bp);
3965 }
3966 else if (op == "BiasGelu")
3967 {
3968 const onnx::TensorProto& B = weights[node.input(1)];
3969
3970 fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
3971
3972 int quantize_tag = 0;
3973 fwrite(&quantize_tag, sizeof(int), 1, bp);
3974
3975 fwrite_tensor_proto_data(B, bp);
3976 }
3977 else if (op == "Ceil")
3978 {
3979 int op_type = 3;
3980 fprintf(pp, " 0=%d", op_type);
3981 }
3982 else if (op == "Clip")
3983 {
3984 float min;
3985 float max;
3986 if (node.input_size() == 1)
3987 {
3988 min = get_node_attr_f(node, "min", -FLT_MAX);
3989 max = get_node_attr_f(node, "max", FLT_MAX);
3990 }
3991 else
3992 {
3993 min = weights.find(node.input(1)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(1)]) : -FLT_MAX;
3994 max = weights.find(node.input(2)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(2)]) : FLT_MAX;
3995 }
3996
3997 fprintf(pp, " 0=%e", min);
3998 fprintf(pp, " 1=%e", max);
3999 }
4000 else if (op == "Concat")
4001 {
4002 int axis = get_node_attr_i(node, "axis", 1);
4003 fprintf(pp, " 0=%d", axis - 1);
4004 }
4005 else if (op == "Constant")
4006 {
4007 // never reach here
4008 }
4009 else if (op == "Conv")
4010 {
4011 const onnx::TensorProto& W = weights[node.input(1)];
4012
4013 int num_filter = W.dims(0);
4014 int has_bias = node.input_size() == 3 ? 1 : 0;
4015
4016 std::string auto_pad = get_node_attr_s(node, "auto_pad");
4017 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
4018 std::vector<int> dilations = get_node_attr_ai(node, "dilations");
4019 std::vector<int> strides = get_node_attr_ai(node, "strides");
4020 std::vector<int> pads = get_node_attr_ai(node, "pads");
4021 int group = get_node_attr_i(node, "group", 1);
4022
4023 fprintf(pp, " 0=%d", num_filter);
4024
4025 if (kernel_shape.size() == 1)
4026 {
4027 fprintf(pp, " 1=%d", kernel_shape[0]);
4028 }
4029 else if (kernel_shape.size() == 2)
4030 {
4031 fprintf(pp, " 1=%d", kernel_shape[1]);
4032 fprintf(pp, " 11=%d", kernel_shape[0]);
4033 }
4034
4035 if (dilations.size() == 1)
4036 {
4037 fprintf(pp, " 2=%d", dilations[0]);
4038 }
4039 else if (dilations.size() == 2)
4040 {
4041 fprintf(pp, " 2=%d", dilations[1]);
4042 fprintf(pp, " 12=%d", dilations[0]);
4043 }
4044
4045 if (strides.size() == 1)
4046 {
4047 fprintf(pp, " 3=%d", strides[0]);
4048 }
4049 else if (strides.size() == 2)
4050 {
4051 fprintf(pp, " 3=%d", strides[1]);
4052 fprintf(pp, " 13=%d", strides[0]);
4053 }
4054
4055 if (auto_pad == "SAME_UPPER")
4056 {
4057 fprintf(pp, " 4=-233");
4058 }
4059 else if (auto_pad == "SAME_LOWER")
4060 {
4061 fprintf(pp, " 4=-234");
4062 }
4063 else
4064 {
4065 if (pads.size() == 1)
4066 {
4067 fprintf(pp, " 4=%d", pads[0]);
4068 }
4069 else if (pads.size() == 2)
4070 {
4071 fprintf(pp, " 4=%d", pads[1]);
4072 fprintf(pp, " 14=%d", pads[0]);
4073 }
4074 else if (pads.size() == 4)
4075 {
4076 fprintf(pp, " 4=%d", pads[1]);
4077 fprintf(pp, " 14=%d", pads[0]);
4078 fprintf(pp, " 15=%d", pads[3]);
4079 fprintf(pp, " 16=%d", pads[2]);
4080 }
4081 }
4082
4083 fprintf(pp, " 5=%d", has_bias);
4084
4085 fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));
4086
4087 if (group > 1)
4088 {
4089 fprintf(pp, " 7=%d", group);
4090 }
4091
4092 int quantize_tag = 0;
4093 fwrite(&quantize_tag, sizeof(int), 1, bp);
4094
4095 fwrite_tensor_proto_data(W, bp);
4096
4097 if (has_bias)
4098 {
4099 const onnx::TensorProto& B = weights[node.input(2)];
4100 fwrite_tensor_proto_data(B, bp);
4101 }
4102 }
4103 else if (op == "ConvTranspose")
4104 {
4105 const onnx::TensorProto& W = weights[node.input(1)];
4106
4107 int has_bias = node.input_size() == 3 ? 1 : 0;
4108
4109 std::string auto_pad = get_node_attr_s(node, "auto_pad");
4110 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
4111 std::vector<int> dilations = get_node_attr_ai(node, "dilations");
4112 std::vector<int> strides = get_node_attr_ai(node, "strides");
4113 std::vector<int> output_padding = get_node_attr_ai(node, "output_padding");
4114 std::vector<int> output_shape = get_node_attr_ai(node, "output_shape");
4115 std::vector<int> pads = get_node_attr_ai(node, "pads");
4116 int group = get_node_attr_i(node, "group", 1);
4117 int num_filter = W.dims(1) * group;
4118
4119 fprintf(pp, " 0=%d", num_filter);
4120
4121 if (kernel_shape.size() == 1)
4122 {
4123 fprintf(pp, " 1=%d", kernel_shape[0]);
4124 }
4125 else if (kernel_shape.size() == 2)
4126 {
4127 fprintf(pp, " 1=%d", kernel_shape[1]);
4128 fprintf(pp, " 11=%d", kernel_shape[0]);
4129 }
4130
4131 if (dilations.size() == 1)
4132 {
4133 fprintf(pp, " 2=%d", dilations[0]);
4134 }
4135 else if (dilations.size() == 2)
4136 {
4137 fprintf(pp, " 2=%d", dilations[1]);
4138 fprintf(pp, " 12=%d", dilations[0]);
4139 }
4140
4141 if (strides.size() == 1)
4142 {
4143 fprintf(pp, " 3=%d", strides[0]);
4144 }
4145 else if (strides.size() == 2)
4146 {
4147 fprintf(pp, " 3=%d", strides[1]);
4148 fprintf(pp, " 13=%d", strides[0]);
4149 }
4150
4151 if (auto_pad == "SAME_UPPER")
4152 {
4153 fprintf(pp, " 4=-233");
4154 }
4155 else if (auto_pad == "SAME_LOWER")
4156 {
4157 fprintf(pp, " 4=-234");
4158 }
4159 else
4160 {
4161 if (pads.size() == 1)
4162 {
4163 fprintf(pp, " 4=%d", pads[0]);
4164 }
4165 else if (pads.size() == 2)
4166 {
4167 fprintf(pp, " 4=%d", pads[1]);
4168 fprintf(pp, " 14=%d", pads[0]);
4169 }
4170 else if (pads.size() == 4)
4171 {
4172 fprintf(pp, " 4=%d", pads[1]);
4173 fprintf(pp, " 14=%d", pads[0]);
4174 fprintf(pp, " 15=%d", pads[3]);
4175 fprintf(pp, " 16=%d", pads[2]);
4176 }
4177 }
4178
4179 if (output_padding.size() == 1)
4180 {
4181 fprintf(pp, " 18=%d", output_padding[0]);
4182 }
4183 else if (output_padding.size() == 2)
4184 {
4185 fprintf(pp, " 18=%d", output_padding[1]);
4186 fprintf(pp, " 19=%d", output_padding[0]);
4187 }
4188
4189 if (output_shape.size() == 1)
4190 {
4191 fprintf(pp, " 20=%d", output_shape[0]);
4192 }
4193 else if (output_shape.size() == 2)
4194 {
4195 fprintf(pp, " 20=%d", output_shape[1]);
4196 fprintf(pp, " 21=%d", output_shape[0]);
4197 }
4198
4199 fprintf(pp, " 5=%d", has_bias);
4200
4201 fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));
4202
4203 if (group > 1)
4204 {
4205 fprintf(pp, " 7=%d", group);
4206 }
4207
4208 int quantize_tag = 0;
4209 fwrite(&quantize_tag, sizeof(int), 1, bp);
4210
4211 int maxk = 0;
4212 if (kernel_shape.size() == 2)
4213 {
4214 maxk = kernel_shape[1] * kernel_shape[0];
4215 }
4216 else
4217 {
4218 maxk = kernel_shape[0] * kernel_shape[0];
4219 }
4220 int weight_data_size = get_tensor_proto_data_size(W);
4221 const float* weight_data = 0;
4222 if (W.has_raw_data())
4223 {
4224 weight_data = (const float*)W.raw_data().data();
4225 }
4226 else if (W.data_type() == 1)
4227 {
4228 weight_data = W.float_data().data();
4229 }
4230 for (int g = 0; g < group; g++)
4231 {
4232 // reorder weight from inch-outch to outch-inch
4233 int num_filter_g = num_filter / group;
4234 int num_input = weight_data_size / maxk / num_filter_g / group;
4235 const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input;
4236 for (int k = 0; k < num_filter_g; k++)
4237 {
4238 for (int j = 0; j < num_input; j++)
4239 {
4240 fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp);
4241 }
4242 }
4243 }
4244
4245 if (has_bias)
4246 {
4247 const onnx::TensorProto& B = weights[node.input(2)];
4248 fwrite_tensor_proto_data(B, bp);
4249 }
4250 }
        else if (op == "Cos")
        {
            // unary op: cosine (type id 10)
            int op_type = 10;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "DepthToSpace")
        {
            // pixelshuffle
            // blocksize -> upscale factor (param 0); mode picks the channel
            // layout: CRD writes 1=0, DCR writes 1=1, any other string leaves
            // param 1 at its layer default
            int scale_factor = get_node_attr_i(node, "blocksize", 1);
            std::string mode = get_node_attr_s(node, "mode");
            fprintf(pp, " 0=%d", scale_factor);
            if (mode == "CRD")
            {
                fprintf(pp, " 1=0");
            }
            else if (mode == "DCR")
            {
                fprintf(pp, " 1=1");
            }
        }
        else if (op == "Div")
        {
            // binary op: division (type id 3)
            int op_type = 3;
            fprintf(pp, " 0=%d", op_type);

            // with_scalar/b are synthesized attributes (not standard ONNX):
            // when set, the second operand is the constant scalar b
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Dropout")
        {
            // no-op
        }
        else if (op == "Elu")
        {
            // only alpha is emitted; default matches the ONNX Elu default of 1.0
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            fprintf(pp, " 0=%e", alpha);
        }
        else if (op == "EmbedLayerNormalization")
        {
            // constant inputs: 2=word embedding table, 3=position embedding
            // table, 5/6 presumably layernorm gamma/beta (named W/B here) --
            // inputs 0, 1 and 4 are runtime blobs and are not consumed
            const onnx::TensorProto& words = weights[node.input(2)];
            const onnx::TensorProto& positions = weights[node.input(3)];
            const onnx::TensorProto& W = weights[node.input(5)];
            const onnx::TensorProto& B = weights[node.input(6)];

            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
            fprintf(pp, " 1=%d", get_tensor_proto_data_size(words));
            fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions));

            // each weight array is preceded by a 0 tag (no quantization)
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(words, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(positions, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);
        }
        else if (op == "Exp")
        {
            // unary op: exponential (type id 7)
            int op_type = 7;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Flatten")
        {
            // only axis==1 is representable; other axes are reported on
            // stderr but conversion continues with the default behavior
            int axis = get_node_attr_i(node, "axis", 1);
            if (axis != 1)
            {
                fprintf(stderr, "Unsupported Flatten axis %d!\n", axis);
            }
        }
        else if (op == "Floor")
        {
            // unary op: floor (type id 2)
            int op_type = 2;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Gemm")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 1.f);
            int transA = get_node_attr_i(node, "transA", 0);
            int transB = get_node_attr_i(node, "transB", 0);

            if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
            {
                // InnerProduct-like A * B + C
                // B (already transposed) is the weight matrix, C the bias;
                // 0=num_output (bias length), 1=1 bias present, 2=weight count
                const onnx::TensorProto& B = weights[node.input(1)];
                const onnx::TensorProto& C = weights[node.input(2)];

                fprintf(pp, " 0=%d", get_tensor_proto_data_size(C));
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=%d", get_tensor_proto_data_size(B));

                // 0 tag = raw float weights follow
                int quantize_tag = 0;
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                fwrite_tensor_proto_data(B, bp);
                fwrite_tensor_proto_data(C, bp);
            }
            else
            {
                // gemm
                // general case: pass alpha/beta/transA/transB straight through
                fprintf(pp, " 0=%e", alpha);
                fprintf(pp, " 1=%e", beta);
                fprintf(pp, " 2=%d", transA);
                fprintf(pp, " 3=%d", transB);
            }
        }
        else if (op == "GlobalAveragePool")
        {
            // pooling: 0=pooling type (1=average), 4=global flag
            int pool = 1;
            int global_pool = 1;

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 4=%d", global_pool);
        }
        else if (op == "GlobalMaxPool")
        {
            // pooling: 0=pooling type (0=max), 4=global flag
            int pool = 0;
            int global_pool = 1;

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 4=%d", global_pool);
        }
        else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
        {
            // torch-style adaptive pooling: target output shape comes from
            // the second (constant) input rather than from attributes
            int pool = 0;
            if (op == "adaptive_avg_pool2d")
            {
                pool = 1;
            }
            int adaptive_pooling = 1;
            const onnx::TensorProto& out_shape_tp = weights[node.input(1)];
            std::vector<int> out_shape = get_node_attr_from_input_ai(out_shape_tp);

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 7=%d", adaptive_pooling);
            if (out_shape.size() == 1)
            {
                // single value applies to both dimensions
                fprintf(pp, " 8=%d", out_shape[0]);
            }
            else if (out_shape.size() == 2)
            {
                // out_w
                fprintf(pp, " 8=%d", out_shape[1]);
                // out_h
                fprintf(pp, " 18=%d", out_shape[0]);
            }
        }
        else if (op == "GroupNorm")
        {
            int groups = get_node_attr_i(node, "groups", 1);
            int channels = get_node_attr_i(node, "channels", 1);
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);
            int affine = get_node_attr_i(node, "affine", 1);

            if (affine)
            {
                // discard affine-less S=1 B=0
                // if every scale is 1 and every bias is 0 the affine transform
                // is an identity, so drop it to avoid writing useless weights
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
                if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
                {
                    affine = 0;
                }
                else
                {
                    // re-enable affine only if some element differs from identity
                    // NOTE(review): loop bound is `channels` (attribute) while the
                    // index goes into affine_S/affine_B -- assumes the scale/bias
                    // tensors have at least `channels` elements; confirm upstream
                    affine = 0;
                    {
                        for (int j = 0; j < channels; j++)
                        {
                            if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                            {
                                affine = 1;
                                break;
                            }
                        }
                    }
                }
            }

            fprintf(pp, " 0=%d", groups);
            fprintf(pp, " 1=%d", channels);
            fprintf(pp, " 2=%e", eps);
            fprintf(pp, " 3=%d", affine);
            if (affine)
            {
                // scale (gamma) then bias (beta), raw float
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "GRU")
        {
            // constant inputs: 1=W (input-hidden weights), 2=R (hidden-hidden
            // weights), 3=B (concatenated input-side and hidden-side biases)
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // 0=forward, 1=reverse, 2=bidirectional
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            int quantize_tag = 0;

            // reorder num_directions-URN-hidden-size to num_directions-RUN-hidden-size
            // ONNX stores the three gates as update(U), reset(R), new(N);
            // the target layout wants reset first, so swap the U and R slabs
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // per-gate slab size within one direction
                int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions;
                const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

                const float* uptr = wptr;
                const float* rptr = wptr + weight_data_size_g;
                const float* nptr = wptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    // second direction: same swap on the next 3-gate slab
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }

            // reduce U and R bias except N
            // reorder num_directions-URN-hidden to num_directions-RUN-hidden
            // ONNX B holds 6 slabs per direction: Wb[u,r,n] then Rb[u,r,n];
            // the R and U biases can be pre-summed, but the N input-side and
            // hidden-side biases must stay separate (N gate applies the reset
            // factor between them), so they are written as two slabs
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions;
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* wuptr = bptr;
                const float* wrptr = bptr + bias_data_size_g;
                const float* wnptr = bptr + bias_data_size_g * 2;
                const float* buptr = bptr + bias_data_size_g * 3;
                const float* brptr = bptr + bias_data_size_g * 4;
                const float* bnptr = bptr + bias_data_size_g * 5;

                // reset gate bias = input-side + hidden-side
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wrptr[j] + brptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                // update gate bias = input-side + hidden-side
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wuptr[j] + buptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                // new gate biases kept separate
                fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                fwrite(bnptr, sizeof(float), bias_data_size_g, bp);

                if (direction_type == 2)
                {
                    // advance all six pointers to the reverse direction
                    wuptr += bias_data_size_g * 6;
                    wrptr += bias_data_size_g * 6;
                    wnptr += bias_data_size_g * 6;
                    buptr += bias_data_size_g * 6;
                    brptr += bias_data_size_g * 6;
                    bnptr += bias_data_size_g * 6;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wrptr[j] + brptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wuptr[j] + buptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                    fwrite(bnptr, sizeof(float), bias_data_size_g, bp);
                }
            }

            // reorder num_directions-URN-hidden-hidden to num_directions-RUN-hidden-hidden
            // same U/R swap for the recurrent (hidden-hidden) weights
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions;
                const float* Rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

                const float* uptr = Rptr;
                const float* rptr = Rptr + weight_data_size_g;
                const float* nptr = Rptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }
        }
        else if (op == "HardSigmoid")
        {
            // piecewise-linear sigmoid approximation, parameterized by alpha/beta
            float alpha = get_node_attr_f(node, "alpha", 0.2f);
            float beta = get_node_attr_f(node, "beta", 0.5f);

            fprintf(pp, " 0=%e", alpha);
            fprintf(pp, " 1=%e", beta);
        }
        else if (op == "HardSwish")
        {
            // same alpha/beta parameterization as HardSigmoid
            float alpha = get_node_attr_f(node, "alpha", 0.2f);
            float beta = get_node_attr_f(node, "beta", 0.5f);

            fprintf(pp, " 0=%e", alpha);
            fprintf(pp, " 1=%e", beta);
        }
        else if (op == "ImageScaler")
        {
            // per-channel bias, single shared scale; channel count is inferred
            // from the bias length
            std::vector<float> bias = get_node_attr_af(node, "bias");
            float scale = get_node_attr_f(node, "scale", 1.f);

            int channels = (int)bias.size();

            fprintf(pp, " 0=%d", channels);
            fprintf(pp, " 1=1");

            // replicate the scalar scale once per channel, then the bias array
            // NOTE(review): &bias[0] assumes the bias attribute is non-empty
            for (int j = 0; j < channels; j++)
            {
                fwrite(&scale, sizeof(float), 1, bp);
            }
            fwrite(&bias[0], sizeof(float), channels, bp);
        }
        else if (op == "InstanceNormalization")
        {
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);

            // discard affine-less S=1 B=0
            // affine is enabled only if some scale/bias element is not identity
            std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
            std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
            int channels = (int)affine_S.size();
            int affine = 0;
            {
                for (int j = 0; j < channels; j++)
                {
                    if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                    {
                        affine = 1;
                        break;
                    }
                }
            }

            fprintf(pp, " 0=%d", channels);
            fprintf(pp, " 1=%e", eps);
            fprintf(pp, " 2=%d", affine);
            if (affine)
            {
                // scale (gamma) then bias (beta), raw float
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "LayerNorm")
        {
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);
            int affine = get_node_attr_i(node, "affine", 1);

            if (affine)
            {
                // discard affine-less S=1 B=0
                // identical identity-detection as GroupNorm/InstanceNorm, but
                // bounded by the actual scale tensor length
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
                int affine_size = (int)affine_S.size();
                affine = 0;
                {
                    for (int j = 0; j < affine_size; j++)
                    {
                        if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                        {
                            affine = 1;
                            break;
                        }
                    }
                }

                // param 0 (normalized size) is only written when affine survives
                if (affine)
                {
                    fprintf(pp, " 0=%d", affine_size);
                }
            }

            fprintf(pp, " 1=%e", eps);
            fprintf(pp, " 2=%d", affine);

            if (affine)
            {
                // scale (gamma) then bias (beta), raw float
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "LeakyRelu")
        {
            // negative-slope coefficient; 0.01 matches the ONNX default
            float alpha = get_node_attr_f(node, "alpha", 0.01f);

            fprintf(pp, " 0=%e", alpha);
        }
        else if (op == "Log")
        {
            // unary op: natural logarithm (type id 8)
            int op_type = 8;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "LRN")
        {
            // local response normalization across channels (norm_region 0)
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 0.5f);
            float bias = get_node_attr_f(node, "bias", 1.f);
            int size = get_node_attr_i(node, "size", 1);

            int norm_region = 0;

            fprintf(pp, " 0=%d", norm_region);
            fprintf(pp, " 1=%d", size);
            fprintf(pp, " 2=%e", alpha);
            fprintf(pp, " 3=%e", beta);
            fprintf(pp, " 4=%e", bias);
        }
        else if (op == "LSTM")
        {
            // constant inputs: 1=W (input-hidden weights), 2=R (recurrent
            // weights), 3=B (concatenated input-side and hidden-side biases)
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // 0=forward, 1=reverse, 2=bidirectional
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            int quantize_tag = 0;

            // reorder num_directions-IOFG-hidden-size to num_directions-IFOG-hidden-size
            // ONNX gate order is input, output, forget, cell(g); the target
            // layout wants input, forget, output, cell -- swap the O/F slabs
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // per-gate slab size within one direction
                int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions;
                const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

                const float* iptr = wptr;
                const float* optr = wptr + weight_data_size_g;
                const float* fptr = wptr + weight_data_size_g * 2;
                const float* gptr = wptr + weight_data_size_g * 3;
                fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                fwrite(optr, sizeof(float), weight_data_size_g, bp);
                fwrite(gptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    // second direction: same swap on the next 4-gate slab
                    iptr += weight_data_size_g * 4;
                    optr += weight_data_size_g * 4;
                    fptr += weight_data_size_g * 4;
                    gptr += weight_data_size_g * 4;
                    fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(optr, sizeof(float), weight_data_size_g, bp);
                    fwrite(gptr, sizeof(float), weight_data_size_g, bp);
                }
            }

            // reduce xc and hc bias
            // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden
            // ONNX B holds 8 slabs per direction (input-side then hidden-side,
            // each IOFG); the two sides are summed element-wise per gate and
            // written in IFOG order
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions;
                const float* xcbptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* xiptr = xcbptr;
                const float* xoptr = xcbptr + bias_data_size_g;
                const float* xfptr = xcbptr + bias_data_size_g * 2;
                const float* xgptr = xcbptr + bias_data_size_g * 3;
                const float* hiptr = xcbptr + bias_data_size_g * 4;
                const float* hoptr = xcbptr + bias_data_size_g * 5;
                const float* hfptr = xcbptr + bias_data_size_g * 6;
                const float* hgptr = xcbptr + bias_data_size_g * 7;

                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xiptr[j] + hiptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xfptr[j] + hfptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xoptr[j] + hoptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xgptr[j] + hgptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }

                if (direction_type == 2)
                {
                    // advance all eight pointers to the reverse direction
                    xiptr += bias_data_size_g * 8;
                    xoptr += bias_data_size_g * 8;
                    xfptr += bias_data_size_g * 8;
                    xgptr += bias_data_size_g * 8;
                    hiptr += bias_data_size_g * 8;
                    hoptr += bias_data_size_g * 8;
                    hfptr += bias_data_size_g * 8;
                    hgptr += bias_data_size_g * 8;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xiptr[j] + hiptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xfptr[j] + hfptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xoptr[j] + hoptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xgptr[j] + hgptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                }
            }

            // reorder num_directions-IOFG-hidden-hidden to num_directions-IFOG-hidden-hidden
            // same O/F swap for the recurrent (hidden-hidden) weights
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions;
                const float* rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

                const float* iptr = rptr;
                const float* optr = rptr + weight_data_size_g;
                const float* fptr = rptr + weight_data_size_g * 2;
                const float* gptr = rptr + weight_data_size_g * 3;
                fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                fwrite(optr, sizeof(float), weight_data_size_g, bp);
                fwrite(gptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    iptr += weight_data_size_g * 4;
                    optr += weight_data_size_g * 4;
                    fptr += weight_data_size_g * 4;
                    gptr += weight_data_size_g * 4;
                    fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(optr, sizeof(float), weight_data_size_g, bp);
                    fwrite(gptr, sizeof(float), weight_data_size_g, bp);
                }
            }
        }
        else if (op == "MatMul")
        {
            // a MatMul against a constant 2-D weight becomes a fully-connected
            // layer; anything else stays a runtime matrix multiplication
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                // InnerProduct
                const onnx::TensorProto& B = weights[node.input(1)];

                int weight_data_size = get_tensor_proto_data_size(B);

                // ONNX stores B as (num_input, num_output); last dim is the output
                int num_output = B.dims(B.dims_size() - 1);
                int num_input = weight_data_size / num_output;

                fprintf(pp, " 0=%d", num_output);
                fprintf(pp, " 1=0");
                fprintf(pp, " 2=%d", weight_data_size);

                // 0 tag = raw float weights follow
                int quantize_tag = 0;
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // reorder num_input-num_output to num_output-num_input
                {
                    const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                    // element-wise transpose: write column j of B as row j
                    for (int j = 0; j < num_output; j++)
                    {
                        for (int k = 0; k < num_input; k++)
                        {
                            float vb = bptr[k * num_output + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }

                // fwrite_tensor_proto_data(B, bp)
            }
            else
            {
                // default matrix multiplication
            }
        }
        else if (op == "Max")
        {
            // binary op: element-wise maximum (type id 4)
            int op_type = 4;
            fprintf(pp, " 0=%d", op_type);

            // optional folded scalar operand (see Div above)
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Min")
        {
            // binary op: element-wise minimum (type id 5)
            int op_type = 5;
            fprintf(pp, " 0=%d", op_type);

            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Mul")
        {
            // binary op: element-wise multiplication (type id 2)
            int op_type = 2;
            fprintf(pp, " 0=%d", op_type);

            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "MultiHeadAttention")
        {
            int embed_dim = get_node_attr_i(node, "embed_dim", 0);
            int num_heads = get_node_attr_i(node, "num_heads", 0);

            fprintf(pp, " 0=%d", embed_dim);
            fprintf(pp, " 1=%d", num_heads);

            // two weight layouts are accepted:
            //  - 5 inputs: fused qkv weight/bias + output weight/bias
            //  - otherwise: separate q/k/v/out weight+bias pairs at inputs 3..10
            if (node.input_size() == 5)
            {
                const onnx::TensorProto& qkvw = weights[node.input(1)];
                const onnx::TensorProto& qkvb = weights[node.input(2)];
                const onnx::TensorProto& ow = weights[node.input(3)];
                const onnx::TensorProto& ob = weights[node.input(4)];

                int weight_data_size = get_tensor_proto_data_size(ow);

                fprintf(pp, " 2=%d", weight_data_size);

                // 0 tag before each raw float weight array
                int quantize_tag = 0;

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose qw
                // fused qkv weight rows each hold [q | k | v] (3*embed_dim
                // columns); the q slice is columns [0, embed_dim), transposed
                // to output-major on the fly
                {
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim * 3 + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose kw
                // k slice is columns [embed_dim, 2*embed_dim); bias offset matches
                {
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
                    bptr += embed_dim;

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim * 3 + j + embed_dim];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose vw
                // v slice is columns [2*embed_dim, 3*embed_dim)
                {
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
                    bptr += embed_dim * 2;

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim * 3 + j + embed_dim * 2];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose ow
                {
                    const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(ob, bp);
            }
            else
            {
                // separate per-projection tensors
                // NOTE(review): the TensorProto reference `vb` below is shadowed
                // by the loop-local `float vb` inside each transpose loop -- it
                // compiles and behaves correctly, but the reuse is easy to misread
                const onnx::TensorProto& qw = weights[node.input(3)];
                const onnx::TensorProto& qb = weights[node.input(4)];
                const onnx::TensorProto& kw = weights[node.input(5)];
                const onnx::TensorProto& kb = weights[node.input(6)];
                const onnx::TensorProto& vw = weights[node.input(7)];
                const onnx::TensorProto& vb = weights[node.input(8)];
                const onnx::TensorProto& ow = weights[node.input(9)];
                const onnx::TensorProto& ob = weights[node.input(10)];

                int weight_data_size = get_tensor_proto_data_size(qw);

                fprintf(pp, " 2=%d", weight_data_size);

                int quantize_tag = 0;

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose qw
                {
                    const float* wptr = qw.has_raw_data() ? (const float*)qw.raw_data().data() : qw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(qb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose kw
                {
                    const float* wptr = kw.has_raw_data() ? (const float*)kw.raw_data().data() : kw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(kb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose vw
                {
                    const float* wptr = vw.has_raw_data() ? (const float*)vw.raw_data().data() : vw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(vb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose ow
                {
                    const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(ob, bp);
            }
        }
        else if (op == "Neg")
        {
            // unary op: negation (type id 1)
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Normalize")
        {
            // channel-shared L2-style normalization with a fixed scale of 1
            float eps = get_node_attr_f(node, "eps", 0.f);
            int scale_data_size = 1;

            fprintf(pp, " 1=1"); // channel_shared
            fprintf(pp, " 2=%e", eps);
            fprintf(pp, " 3=%d", scale_data_size);
            fprintf(pp, " 9=1"); // TODO hardcode pytorch style

            // single shared scale value written to the weight stream
            const float scale_data[1] = {1.f};
            fwrite(scale_data, sizeof(float), 1, bp);
        }
        else if (op == "Pad")
        {
            std::string mode = get_node_attr_s(node, "mode");
            float value = get_node_attr_f(node, "value", 0.f);

            // pads come from the attribute (opset <= 10) or from the second
            // constant input (opset 11+)
            std::vector<int> pads;
            if (node.input_size() == 1)
            {
                pads = get_node_attr_ai(node, "pads");
            }
            else
            {
                pads = get_node_attr_from_input_ai(weights[node.input(1)]);
            }

            // 0=constant, 1=edge/replicate, 2=reflect
            int type = 0;
            if (mode == "constant")
            {
                type = 0;
            }
            else if (mode == "edge")
            {
                type = 1;
            }
            else if (mode == "reflect")
            {
                type = 2;
            }

            // ONNX pads are [begin dims..., end dims...]; unpack per rank
            int pad_size = (int)pads.size();
            int top = 0;
            int bottom = 0;
            int left = 0;
            int right = 0;
            int front = 0;
            int behind = 0;
            if (pad_size == 8)
            {
                //NCHW
                top = pads[2];
                bottom = pads[6];
                left = pads[3];
                right = pads[7];
                front = pads[1];
                behind = pads[5];
            }
            else if (pad_size == 6)
            {
                //NHW
                top = pads[1];
                bottom = pads[4];
                left = pads[2];
                right = pads[5];
            }
            else
            {
                //NW
                // NOTE(review): fallthrough assumes pad_size == 4 -- a smaller
                // pads array would read out of bounds here; confirm callers
                left = pads[1];
                right = pads[3];
            }

            fprintf(pp, " 0=%d", top);
            fprintf(pp, " 1=%d", bottom);
            fprintf(pp, " 2=%d", left);
            fprintf(pp, " 3=%d", right);
            fprintf(pp, " 4=%d", type);
            fprintf(pp, " 5=%e", value);
            fprintf(pp, " 7=%d", front);
            fprintf(pp, " 8=%d", behind);
        }
        else if (op == "Pow")
        {
            // binary op: power (type id 6)
            int op_type = 6;
            fprintf(pp, " 0=%d", op_type);

            // optional folded scalar exponent (see Div above)
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "PixelShuffle")
        {
            int scale_factor = get_node_attr_i(node, "scale_factor", 1);
            fprintf(pp, " 0=%d", scale_factor);
        }
        else if (op == "PRelu")
        {
            // per-channel (or shared) negative slopes come from input 1
            const onnx::TensorProto& slope = weights[node.input(1)];

            int num_slope = get_tensor_proto_data_size(slope);

            fprintf(pp, " 0=%d", num_slope);

            fwrite_tensor_proto_data(slope, bp);
        }
        else if (op == "Reciprocal")
        {
            // unary op: reciprocal (type id 15)
            int op_type = 15;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
        {
            // map each ONNX reduction to the reduction type id; -233 would
            // mean "unmapped" but every listed op is covered below
            int op_type = -233;
            if (op == "ReduceSum")
                op_type = 0;
            else if (op == "ReduceSumSquare")
                op_type = 2;
            else if (op == "ReduceMean")
                op_type = 3;
            else if (op == "ReduceMax")
                op_type = 4;
            else if (op == "ReduceMin")
                op_type = 5;
            else if (op == "ReduceProd")
                op_type = 6;
            else if (op == "ReduceL1")
                op_type = 7;
            else if (op == "ReduceL2")
                op_type = 8;
            else if (op == "ReduceLogSum")
                op_type = 9;
            else if (op == "ReduceLogSumExp")
                op_type = 10;
            fprintf(pp, " 0=%d", op_type);

            std::vector<int> axes = get_node_attr_ai(node, "axes");
            int keepdims = get_node_attr_i(node, "keepdims", 1);

            if (axes.size() > 0)
            {
                // if axes set, reduce according to axes
                // 1=0 disables reduce-all; -23303 introduces the axes array
                fprintf(pp, " 1=%d", 0);
                fprintf(pp, " -23303=%zu", axes.size());
                for (size_t j = 0; j < axes.size(); j++)
                {
                    // axis 0 (batch) and ranks beyond 3 are not representable
                    if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3)
                        fprintf(stderr, "Unsupported reduction axes !\n");
                    fprintf(pp, ",%d", axes[j]);
                }
            }
            else
            {
                // if axes not set, reduce all axes by default
                fprintf(pp, " 1=%d", 1);
            }
            fprintf(pp, " 4=%d", keepdims);
        }
        else if (op == "Reorg")
        {
            int stride = get_node_attr_i(node, "stride", 1);
            fprintf(pp, " 0=%d", stride);
        }
        else if (op == "Reshape")
        {
            // target shape comes from the attribute (opset <= 4) or from the
            // second constant input (opset 5+)
            std::vector<int> shape;

            if (node.input_size() == 1)
            {
                shape = get_node_attr_ai(node, "shape");
            }
            else
            {
                shape = get_node_attr_from_input_ai(weights[node.input(1)]);
            }

            // shape[0] is the batch dimension and is dropped; remaining dims
            // map to w/h/c params in reverse order
            if (shape.size() == 1)
            {
                fprintf(pp, " 0=%d", shape[0]); // should never reach here
            }
            else if (shape.size() == 2)
            {
                fprintf(pp, " 0=%d", shape[1]);
            }
            else if (shape.size() == 3)
            {
                fprintf(pp, " 0=%d", shape[2]);
                fprintf(pp, " 1=%d", shape[1]);
            }
            else if (shape.size() == 4)
            {
                fprintf(pp, " 0=%d", shape[3]);
                fprintf(pp, " 1=%d", shape[2]);
                fprintf(pp, " 2=%d", shape[1]);
            }
            else if (shape.size() == 5)
            {
                // 5-D collapses the two innermost dims into w
                fprintf(pp, " 0=%d", shape[4] * shape[3]);
                fprintf(pp, " 1=%d", shape[2]);
                fprintf(pp, " 2=%d", shape[1]);
            }
        }
        else if (op == "Resize")
        {
            // Interp layer: interpolation mode + coordinate transformation mode.
            std::string mode = get_node_attr_s(node, "mode");
            std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

            std::vector<float> scales;
            std::vector<int> sizes;
            if (node.input_size() == 2)
            {
                // opset 10
                scales = get_node_attr_from_input_af(weights[node.input(1)]);
            }
            else
            {
                // opset 11+
                // input layout: X, roi, scales[, sizes]
                scales = get_node_attr_from_input_af(weights[node.input(2)]);
                if (node.input_size() >= 4)
                {
                    sizes = get_node_attr_from_input_ai(weights[node.input(3)]);
                }
            }

            // Interp resize_type: 1=nearest 2=bilinear 3=bicubic
            int resize_type = 1;
            if (mode == "nearest")
            {
                resize_type = 1;
            }
            else if (mode == "linear")
            {
                resize_type = 2;
            }
            else if (mode == "cubic")
            {
                resize_type = 3;
            }

            if (scales.empty() && sizes.empty())
            {
                fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n");
            }

            // Per-axis scale factors; index 0 is the batch dim and must stay 1.
            float h_scale = 1.f;
            float w_scale = 1.f;
            if (scales.size() == 2)
            {
                w_scale = scales[1];
            }
            else if (scales.size() == 3)
            {
                h_scale = scales[1];
                w_scale = scales[2];
            }
            else if (scales.size() == 4)
            {
                h_scale = scales[2];
                w_scale = scales[3];

                // scales[1] is the channel scale; ncnn cannot resize channels.
                if (scales[1] != 1.f)
                    fprintf(stderr, "Unsupported Resize scales !\n");
            }

            // Explicit output size (takes the same trailing-axes layout as scales);
            // 0 means "not specified, derive from scale".
            int output_height = 0;
            int output_width = 0;
            if (sizes.size() == 2)
            {
                output_width = sizes[1];
            }
            else if (sizes.size() == 3)
            {
                output_height = sizes[1];
                output_width = sizes[2];
            }
            else if (sizes.size() == 4)
            {
                output_height = sizes[2];
                output_width = sizes[3];
            }

            int align_corner = 0;
            if (align == "align_corners")
            {
                align_corner = 1;
            }

            fprintf(pp, " 0=%d", resize_type);
            fprintf(pp, " 1=%e", h_scale);
            fprintf(pp, " 2=%e", w_scale);
            fprintf(pp, " 3=%d", output_height);
            fprintf(pp, " 4=%d", output_width);
            fprintf(pp, " 6=%d", align_corner);
        }
        else if (op == "RNN")
        {
            // ONNX RNN inputs: X, W (input weights), R (recurrent weights), B (biases).
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // ncnn direction code: 0=forward 1=reverse 2=bidirectional
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            // 0 = raw float32 weights (no quantization)
            int quantize_tag = 0;

            fwrite(&quantize_tag, sizeof(int), 1, bp);
            fwrite_tensor_proto_data(W, bp);

            // reduce xc and hc bias
            // ONNX keeps separate input (Wb) and recurrent (Rb) bias halves per
            // direction; ncnn expects their elementwise sum, so fuse them here.
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // size of one bias half for one direction
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions;
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* xiptr = bptr;
                const float* hiptr = bptr + bias_data_size_g;

                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xiptr[j] + hiptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }

                if (direction_type == 2)
                {
                    // advance past the forward direction's two bias halves
                    // to reach the reverse direction's Wb/Rb pair
                    xiptr += bias_data_size_g * 2;
                    hiptr += bias_data_size_g * 2;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xiptr[j] + hiptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                }
            }

            fwrite(&quantize_tag, sizeof(int), 1, bp);
            fwrite_tensor_proto_data(R, bp);
        }
        else if (op == "ShuffleChannel")
        {
            int group = get_node_attr_i(node, "group", 1);
            int reverse = get_node_attr_i(node, "reverse", 0);
            fprintf(pp, " 0=%d", group);
            fprintf(pp, " 1=%d", reverse);
        }
        else if (op == "Sigmoid")
        {
            // no param
        }
        else if (op == "Sin")
        {
            // UnaryOp: operation type 9 = sin
            int op_type = 9;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "SkipLayerNormalization")
        {
            // Fused skip + layer-norm node; inputs 2..4 are the constant
            // gamma (W), beta (B) and bias (B2) tensors.
            const onnx::TensorProto& W = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];
            const onnx::TensorProto& B2 = weights[node.input(4)];

            // 0 = normalized channel count, taken from beta's element count
            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));

            // 0 = raw float32 weights (no quantization)
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B2, bp);
        }
        else if (op == "Slice")
        {
            std::vector<int> starts;
            std::vector<int> ends;
            std::vector<int> axes;
            std::vector<int> steps;
            if (node.input_size() == 1)
            {
                // opset <10: slice bounds come from node attributes
                starts = get_node_attr_ai(node, "starts");
                ends = get_node_attr_ai(node, "ends");
                axes = get_node_attr_ai(node, "axes");
                steps = get_node_attr_ai(node, "steps"); // TODO
            }
            else
            {
                // opset 10+: bounds come from constant inputs (starts, ends[, axes[, steps]])
                starts = get_node_attr_from_input_ai(weights[node.input(1)]);
                ends = get_node_attr_from_input_ai(weights[node.input(2)]);
                if (node.input_size() >= 4)
                    axes = get_node_attr_from_input_ai(weights[node.input(3)]);
                if (node.input_size() >= 5)
                    steps = get_node_attr_from_input_ai(weights[node.input(4)]);
            }

            // assert step == 1
            // strided slicing is not supported by ncnn Crop
            for (int i = 0; i < (int)steps.size(); i++)
            {
                if (steps[i] != 1)
                    fprintf(stderr, "Unsupported slice step !\n");
            }

            // filter out N-dim axis
            // drop a slice along the batch axis (0) entirely; ncnn blobs have
            // no batch dim. Only the first such entry is removed (break).
            if (!axes.empty())
            {
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    int axis = axes[i];
                    if (axis == 0)
                    {
                        starts.erase(starts.begin() + i);
                        ends.erase(ends.begin() + i);
                        axes.erase(axes.begin() + i);
                        break;
                    }
                }
            }

            // Crop params: -23309=starts array, -23310=ends array, -23311=axes array
            fprintf(pp, " -23309=%d", (int)starts.size());
            for (int i = 0; i < (int)starts.size(); i++)
            {
                fprintf(pp, ",%d", starts[i]);
            }
            fprintf(pp, " -23310=%d", (int)ends.size());
            for (int i = 0; i < (int)ends.size(); i++)
            {
                fprintf(pp, ",%d", ends[i]);
            }
            if (!axes.empty())
            {
                fprintf(pp, " -23311=%d", (int)axes.size());
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    int axis = axes[i];
                    if (axis == 0 || axis > 3 || axis < -3)
                        fprintf(stderr, "Unsupported slice axes !\n");

                    if (axis > 0)
                        axis = axis - 1; // -1 for skip N-dim

                    fprintf(pp, ",%d", axis);
                }
            }
        }
        else if (op == "Softmax")
        {
            // axis - 1 converts from ONNX (batch included) to ncnn (batch dropped)
            int axis = get_node_attr_i(node, "axis", 1);
            fprintf(pp, " 0=%d", axis - 1);
            // 1=1 selects the fixed-axis softmax behavior expected by ncnn
            fprintf(pp, " 1=1");
        }
        else if (op == "Split")
        {
            // Maps to ncnn Slice: -23300 holds the per-output section sizes.
            int axis = get_node_attr_i(node, "axis", 0);
            std::vector<int> split = get_node_attr_ai(node, "split");
            // axis 0 is the batch dim, which ncnn cannot split along
            if (axis < 1)
                fprintf(stderr, "Unsupported split axis !\n");

            fprintf(pp, " -23300=%d", output_size);
            if (split.empty())
            {
                // no explicit sizes: -233 means "divide evenly"
                for (int i = 0; i < output_size; i++)
                {
                    fprintf(pp, ",-233");
                }
            }
            else
            {
                // explicit sizes for all but the last output; the final -233
                // means "whatever remains"
                for (size_t i = 0; i < split.size() - 1; i++)
                {
                    fprintf(pp, ",%d", split[i]);
                }
                fprintf(pp, ",-233");
            }
            fprintf(pp, " 1=%d", axis - 1);
        }
        else if (op == "Sqrt")
        {
            // UnaryOp: operation type 5 = sqrt
            int op_type = 5;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Squeeze")
        {
            std::vector<int> axes = get_node_attr_ai(node, "axes");

            if (axes.empty())
            {
                // no axes given: squeeze all three ncnn dims (w, h, c)
                fprintf(pp, " 0=1");
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=1");
            }
            else
            {
                fprintf(pp, " -23303=%zu", axes.size());
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    // batch axis (0) and anything beyond 3 dims is unsupported
                    if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3)
                        fprintf(stderr, "Unsupported squeeze axes !\n");
                    fprintf(pp, ",%d", axes[i]);
                }
            }
        }
        else if (op == "Sub")
        {
            // BinaryOp: operation type 1 = sub
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);

            // with_scalar/b are converter-internal attributes for a fused
            // tensor-minus-constant form: 1=with_scalar 2=the scalar operand
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Sum")
        {
            // Eltwise: operation type 1 = sum
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Swish")
        {
            // no param
        }
        else if (op == "Tan")
        {
            // UnaryOp: operation type 11 = tan
            int op_type = 11;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Tanh")
        {
            // UnaryOp: operation type 16 = tanh
            int op_type = 16;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Transpose")
        {
            // Map an ONNX permutation (batch dim included) onto one of ncnn
            // Permute's fixed order codes; perm[0] is normally the untouched
            // batch axis.
            std::vector<int> perm = get_node_attr_ai(node, "perm");

            if (perm.size() == 3)
            {
                if (perm[1] == 1 && perm[2] == 2)
                    fprintf(pp, " 0=0"); // w h
                else if (perm[1] == 2 && perm[2] == 1)
                    fprintf(pp, " 0=1"); // h w
                else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2)
                    fprintf(pp, " 0=0"); // w h
                else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1)
                    fprintf(pp, " 0=1"); // h w
            }
            else if (perm.size() == 4)
            {
                if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3)
                    fprintf(pp, " 0=0"); // w h c
                else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2)
                    fprintf(pp, " 0=1"); // h w c
                else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3)
                    fprintf(pp, " 0=2"); // w c h
                else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1)
                    fprintf(pp, " 0=3"); // c w h
                else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2)
                    fprintf(pp, " 0=4"); // h c w
                else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1)
                    fprintf(pp, " 0=5"); // c h w
            }
            else if (perm.size() == 5)
            {
                // 5-d case reuses the 4-d codes by treating the two innermost
                // dims as a combined "wx" axis.
                if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4)
                    fprintf(pp, " 0=0"); // wx h c
                else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2)
                    fprintf(pp, " 0=1"); // h wx c
                else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4)
                    fprintf(pp, " 0=2"); // wx c h
                else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1)
                    fprintf(pp, " 0=3"); // c wx h
                else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2)
                    fprintf(pp, " 0=4"); // h c wx
                else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1)
                    fprintf(pp, " 0=5"); // c h wx
                else
                    fprintf(stderr, "Unsupported transpose type !\n");
            }
        }
        else if (op == "Upsample")
        {
            // Deprecated ONNX Upsample (pre-Resize); maps to ncnn Interp.
            std::string mode = get_node_attr_s(node, "mode");
            std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

            std::vector<float> scales;

            if (node.input_size() == 1)
            {
                // opset <9: scales are a node attribute
                scales = get_node_attr_af(node, "scales");
            }
            else
            {
                // opset 9: scales come from a constant input tensor
                scales = get_node_attr_from_input_af(weights[node.input(1)]);
            }

            // Interp resize_type: 1=nearest 2=bilinear
            int resize_type = 1;
            if (mode == "nearest")
            {
                resize_type = 1;
            }
            else if (mode == "bilinear" || mode == "linear")
            {
                resize_type = 2;
            }
            else if (mode == "trilinear")
            {
                fprintf(stderr, "Unsupported Upsample mode !\n");
            }

            // Per-axis scale factors; index 0 is the batch dim and must stay 1.
            float h_scale = 1.f;
            float w_scale = 1.f;
            if (scales.size() == 2)
            {
                w_scale = scales[1];
            }
            else if (scales.size() == 3)
            {
                h_scale = scales[1];
                w_scale = scales[2];
            }
            else if (scales.size() == 4)
            {
                h_scale = scales[2];
                w_scale = scales[3];

                // scales[1] is the channel scale; ncnn cannot resize channels
                if (scales[1] != 1.f)
                    fprintf(stderr, "Unsupported Upsample scales !\n");
            }
            else
            {
                fprintf(stderr, "Unsupported Upsample scales !\n");
            }

            int align_corner = 0;
            if (align == "align_corners")
            {
                align_corner = 1;
            }

            fprintf(pp, " 0=%d", resize_type);
            fprintf(pp, " 1=%e", h_scale);
            fprintf(pp, " 2=%e", w_scale);
            fprintf(pp, " 6=%d", align_corner);
        }
        else if (op == "Unsqueeze")
        {
            std::vector<int> axes = get_node_attr_ai(node, "axes");

            fprintf(pp, " -23303=%zu", axes.size());
            for (int i = 0; i < (int)axes.size(); i++)
            {
                // note the wider |axis| <= 4 bound: unsqueeze may create a dim
                if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4)
                    fprintf(stderr, "Unsupported unsqueeze axes !\n");
                fprintf(pp, ",%d", axes[i]);
            }
        }
        else
        {
            // TODO op specific param
            // Unknown op: dump its attributes to stderr so the user can see
            // what a manual conversion would need. Numeric type tags follow
            // onnx::AttributeProto::AttributeType (1=FLOAT 2=INT 3=STRING).
            for (int j = 0; j < node.attribute_size(); j++)
            {
                const onnx::AttributeProto& attr = node.attribute(j);
                if (attr.type() == 1)
                {
                    fprintf(stderr, "  # %s=%g\n", attr.name().c_str(), attr.f());
                }
                else if (attr.type() == 2)
                {
                    fprintf(stderr, "  # %s=%lld\n", attr.name().c_str(), (long long)attr.i());
                }
                else if (attr.type() == 3)
                {
                    fprintf(stderr, "  # %s=%s\n", attr.name().c_str(), attr.s().c_str());
                }
                else
                {
                    // unhandled attribute kind: print the raw type tag only
                    fprintf(stderr, "  # %s %d\n", attr.name().c_str(), attr.type());
                }
            }
        }
5855
        // terminate this layer's param line
        fprintf(pp, "\n");

        // ncnn blobs are single-consumer: for every output consumed more than
        // once, append a Split layer that fans it out into refcount copies
        // named <output>_splitncnn_<k>.
        for (int j = 0; j < output_size; j++)
        {
            const std::string& output_name = node.output(j);
            if (node_reference.find(output_name) != node_reference.end())
            {
                int refcount = node_reference[output_name];
                if (refcount > 1)
                {
                    char splitname[256];
                    sprintf(splitname, "splitncnn_%d", internal_split);
                    fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

                    fprintf(pp, " %s", output_name.c_str());

                    for (int k = 0; k < refcount; k++)
                    {
                        fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
                    }
                    fprintf(pp, "\n");

                    // global counter keeps generated Split layer names unique
                    internal_split++;
                }
            }
        }
5882 }
5883
    // flush and close the param (pp) and binary weight (bp) output files
    fclose(pp);
    fclose(bp);

    return 0;
5888 }
5889