1 #include <unordered_map>
2 #include <vector>
3 
4 #include "../../common/Common.hpp"
5 #include "MNN_generated.h"
6 #include "MergeHelpers.hpp"
7 
8 using namespace MNN::Express;
9 
10 namespace MNN {
11 namespace helpers {
12 
IsConstant(EXPRP expr)13 bool IsConstant(EXPRP expr) {
14     const Op* op = expr->get();
15     if ((op && op->type() == OpType_Const) || (!op && expr->inputType() == VARP::CONSTANT)) {
16         return true;
17     }
18     return false;
19 }
20 
IsBinaryOp(EXPRP expr)21 bool IsBinaryOp(EXPRP expr) {
22     const Op* op = expr->get();
23     return op && op->type() == OpType_BinaryOp;
24 }
25 
IsUnaryOp(EXPRP expr)26 bool IsUnaryOp(EXPRP expr) {
27     const Op* op = expr->get();
28     return op && op->type() == OpType_UnaryOp;
29 }
30 
31 #define IS_BINARY_OP_TYPE(op_type)                        \
32     if (!IsBinaryOp(expr)) {                              \
33         return false;                                     \
34     }                                                     \
35     int type = expr->get()->main_as_BinaryOp()->opType(); \
36     return type == op_type;
37 
38 #define IS_UNARY_OP_TYPE(op_type)                        \
39     if (!IsUnaryOp(expr)) {                              \
40         return false;                                    \
41     }                                                    \
42     int type = expr->get()->main_as_UnaryOp()->opType(); \
43     return type == op_type;
44 
IsBinaryAdd(EXPRP expr)45 bool IsBinaryAdd(EXPRP expr) {
46     IS_BINARY_OP_TYPE(BinaryOpOperation_ADD);
47 }
48 
IsBinarySub(EXPRP expr)49 bool IsBinarySub(EXPRP expr) {
50     IS_BINARY_OP_TYPE(BinaryOpOperation_SUB);
51 }
52 
IsBinaryMul(EXPRP expr)53 bool IsBinaryMul(EXPRP expr) {
54     IS_BINARY_OP_TYPE(BinaryOpOperation_MUL);
55 }
56 
IsBinarySquaredDifference(Express::EXPRP expr)57 bool IsBinarySquaredDifference(Express::EXPRP expr) {
58     IS_BINARY_OP_TYPE(BinaryOpOperation_SquaredDifference);
59 }
60 
IsUnarySquare(EXPRP expr)61 bool IsUnarySquare(EXPRP expr) {
62     IS_UNARY_OP_TYPE(UnaryOpOperation_SQUARE);
63 }
64 
IsUnaryRsqrt(EXPRP expr)65 bool IsUnaryRsqrt(EXPRP expr) {
66     IS_UNARY_OP_TYPE(UnaryOpOperation_RSQRT);
67 }
68 
69 #undef IS_BINARY_OP_TYPE
70 #undef IS_UNARY_OP_TYPE
71 
IsReductionMean(EXPRP expr)72 bool IsReductionMean(EXPRP expr) {
73     const Op* op = expr->get();
74     if (!op || op->type() != OpType_Reduction) {
75         return false;
76     }
77     int type = op->main_as_ReductionParam()->operation();
78     return type == ReductionType_MEAN;
79 }
80 
IsConvolution(EXPRP expr)81 bool IsConvolution(EXPRP expr) {
82     const Op* op = expr->get();
83     return op && op->type() == OpType_Convolution;
84 }
85 
IsExpandDims(EXPRP expr)86 bool IsExpandDims(EXPRP expr) {
87     const Op* op = expr->get();
88     return op && op->type() == OpType_ExpandDims;
89 }
90 
InputExpr(EXPRP expr,int input_index)91 EXPRP InputExpr(EXPRP expr, int input_index) {
92     return expr->inputs().at(input_index)->expr().first;
93 }
94 
OutputExpr(EXPRP expr,int output_index)95 EXPRP OutputExpr(EXPRP expr, int output_index) {
96     return expr->outputs().at(output_index).lock();
97 }
98 
OutputVars(EXPRP expr)99 std::vector<VARP> OutputVars(EXPRP expr) {
100     std::unordered_map<int, VARP> outputs;
101     for (WeakEXPRP w : expr->outputs()) {
102         EXPRP child = w.lock();
103         if (!child.get()) {
104             continue;
105         }
106         for (VARP output : child->inputs()) {
107             int output_index = 0;
108             EXPRP parent;
109             std::tie(parent, output_index) = output->expr();
110             if (parent.get() == expr.get()) {
111                 outputs.emplace(output_index, output);
112             }
113         }
114     }
115     std::vector<VARP> v_outputs;
116     for (const auto& it : outputs) {
117         int index = 0;
118         VARP output;
119         std::tie(index, output) = it;
120         if (!output.get()) {
121             continue;
122         }
123         if (v_outputs.size() <= index) {
124             v_outputs.resize(index + 1);
125         }
126         v_outputs[index] = output;
127     }
128     return std::move(v_outputs);
129 }
130 
ConvertLayout(VARP input,Dimensionformat dest_layout,Dimensionformat src_layout)131 VARP ConvertLayout(VARP input, Dimensionformat dest_layout, Dimensionformat src_layout) {
132     std::unique_ptr<OpT> convert(new OpT);
133     convert->type                               = OpType_ConvertTensor;
134     convert->main.type                          = OpParameter_TensorConvertInfo;
135     convert->main.value                         = new TensorConvertInfoT;
136     convert->main.AsTensorConvertInfo()->dest   = convertFormat(dest_layout);
137     convert->main.AsTensorConvertInfo()->source = convertFormat(src_layout);
138     return (Variable::create(Expr::create(convert.get(), {input})));
139 }
140 
141 } // namespace helpers
142 } // namespace MNN
143