1 #include <unordered_map>
2 #include <vector>
3
4 #include "../../common/Common.hpp"
5 #include "MNN_generated.h"
6 #include "MergeHelpers.hpp"
7
8 using namespace MNN::Express;
9
10 namespace MNN {
11 namespace helpers {
12
IsConstant(EXPRP expr)13 bool IsConstant(EXPRP expr) {
14 const Op* op = expr->get();
15 if ((op && op->type() == OpType_Const) || (!op && expr->inputType() == VARP::CONSTANT)) {
16 return true;
17 }
18 return false;
19 }
20
IsBinaryOp(EXPRP expr)21 bool IsBinaryOp(EXPRP expr) {
22 const Op* op = expr->get();
23 return op && op->type() == OpType_BinaryOp;
24 }
25
IsUnaryOp(EXPRP expr)26 bool IsUnaryOp(EXPRP expr) {
27 const Op* op = expr->get();
28 return op && op->type() == OpType_UnaryOp;
29 }
30
31 #define IS_BINARY_OP_TYPE(op_type) \
32 if (!IsBinaryOp(expr)) { \
33 return false; \
34 } \
35 int type = expr->get()->main_as_BinaryOp()->opType(); \
36 return type == op_type;
37
38 #define IS_UNARY_OP_TYPE(op_type) \
39 if (!IsUnaryOp(expr)) { \
40 return false; \
41 } \
42 int type = expr->get()->main_as_UnaryOp()->opType(); \
43 return type == op_type;
44
IsBinaryAdd(EXPRP expr)45 bool IsBinaryAdd(EXPRP expr) {
46 IS_BINARY_OP_TYPE(BinaryOpOperation_ADD);
47 }
48
IsBinarySub(EXPRP expr)49 bool IsBinarySub(EXPRP expr) {
50 IS_BINARY_OP_TYPE(BinaryOpOperation_SUB);
51 }
52
IsBinaryMul(EXPRP expr)53 bool IsBinaryMul(EXPRP expr) {
54 IS_BINARY_OP_TYPE(BinaryOpOperation_MUL);
55 }
56
IsBinarySquaredDifference(Express::EXPRP expr)57 bool IsBinarySquaredDifference(Express::EXPRP expr) {
58 IS_BINARY_OP_TYPE(BinaryOpOperation_SquaredDifference);
59 }
60
IsUnarySquare(EXPRP expr)61 bool IsUnarySquare(EXPRP expr) {
62 IS_UNARY_OP_TYPE(UnaryOpOperation_SQUARE);
63 }
64
IsUnaryRsqrt(EXPRP expr)65 bool IsUnaryRsqrt(EXPRP expr) {
66 IS_UNARY_OP_TYPE(UnaryOpOperation_RSQRT);
67 }
68
69 #undef IS_BINARY_OP_TYPE
70 #undef IS_UNARY_OP_TYPE
71
IsReductionMean(EXPRP expr)72 bool IsReductionMean(EXPRP expr) {
73 const Op* op = expr->get();
74 if (!op || op->type() != OpType_Reduction) {
75 return false;
76 }
77 int type = op->main_as_ReductionParam()->operation();
78 return type == ReductionType_MEAN;
79 }
80
IsConvolution(EXPRP expr)81 bool IsConvolution(EXPRP expr) {
82 const Op* op = expr->get();
83 return op && op->type() == OpType_Convolution;
84 }
85
IsExpandDims(EXPRP expr)86 bool IsExpandDims(EXPRP expr) {
87 const Op* op = expr->get();
88 return op && op->type() == OpType_ExpandDims;
89 }
90
InputExpr(EXPRP expr,int input_index)91 EXPRP InputExpr(EXPRP expr, int input_index) {
92 return expr->inputs().at(input_index)->expr().first;
93 }
94
OutputExpr(EXPRP expr,int output_index)95 EXPRP OutputExpr(EXPRP expr, int output_index) {
96 return expr->outputs().at(output_index).lock();
97 }
98
OutputVars(EXPRP expr)99 std::vector<VARP> OutputVars(EXPRP expr) {
100 std::unordered_map<int, VARP> outputs;
101 for (WeakEXPRP w : expr->outputs()) {
102 EXPRP child = w.lock();
103 if (!child.get()) {
104 continue;
105 }
106 for (VARP output : child->inputs()) {
107 int output_index = 0;
108 EXPRP parent;
109 std::tie(parent, output_index) = output->expr();
110 if (parent.get() == expr.get()) {
111 outputs.emplace(output_index, output);
112 }
113 }
114 }
115 std::vector<VARP> v_outputs;
116 for (const auto& it : outputs) {
117 int index = 0;
118 VARP output;
119 std::tie(index, output) = it;
120 if (!output.get()) {
121 continue;
122 }
123 if (v_outputs.size() <= index) {
124 v_outputs.resize(index + 1);
125 }
126 v_outputs[index] = output;
127 }
128 return std::move(v_outputs);
129 }
130
ConvertLayout(VARP input,Dimensionformat dest_layout,Dimensionformat src_layout)131 VARP ConvertLayout(VARP input, Dimensionformat dest_layout, Dimensionformat src_layout) {
132 std::unique_ptr<OpT> convert(new OpT);
133 convert->type = OpType_ConvertTensor;
134 convert->main.type = OpParameter_TensorConvertInfo;
135 convert->main.value = new TensorConvertInfoT;
136 convert->main.AsTensorConvertInfo()->dest = convertFormat(dest_layout);
137 convert->main.AsTensorConvertInfo()->source = convertFormat(src_layout);
138 return (Variable::create(Expr::create(convert.get(), {input})));
139 }
140
141 } // namespace helpers
142 } // namespace MNN
143