1 //
2 // Conv1dSqueezeMove.cpp
3 // MNNConverter
4 //
5 // Created by MNN on 2021/03/05.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "../TemplateMerge.hpp"
10 #include "MNN_generated.h"
11 #include "MergeHelpers.hpp"
12 #include "cli.hpp"
13 #include "MNN_compression.pb.h"
14 #include <fstream>
15
16 namespace MNN {
17 namespace Express {
18
// Classification of the op that follows a Squeeze, as reported by
// getConv1dPostCase. The numeric values are the ones the compiler would
// assign implicitly; they are spelled out only for readability.
enum Conv1dPostCases {
    // The expression is not a pattern this pass rewrites.
    None = 0,
    // A BinaryOp ADD of a Squeeze output and a constant of rank <= 1.
    BiasAdd = 1,
    // A ReLU / ReLU6 whose input comes from a Squeeze.
    Relu = 2,
    // don't need BiasAddRelu
};
25
__anonb12fa25e0102(EXPRP expr) 26 auto getConv1dPostCase = [](EXPRP expr) {
27 auto noPost = Conv1dPostCases::None;
28 auto returnPost = noPost;
29
30 if (nullptr == expr->get()) {
31 return noPost;
32 }
33
34 auto opType = expr->get()->type();
35
36 auto gConverterConfig = Global<modelConfig>::Get();
37 std::string compressFileName = gConverterConfig->compressionParamsFile;
38 Compression::Pipeline proto;
39 if (compressFileName != "") {
40 std::fstream input(compressFileName.c_str(), std::ios::in | std::ios::binary);
41 if (!proto.ParseFromIstream(&input)) {
42 MNN_ERROR("Failed to parse compression pipeline proto.\n");
43 }
44 }
45
46 auto findQuantParameters = [&](Compression::Pipeline& proto, std::string outputTensorName) {
47 for (const auto& algo : proto.algo()) {
48 if (algo.type() == Compression::CompressionAlgo::QUANTIZE) {
49 auto quantParams = algo.quant_params();
50 for (const auto& layerProto : quantParams.layer()) {
51 const std::string& outputName = layerProto.output(0).name();
52 if (outputName == outputTensorName) {
53 return layerProto;
54 }
55 }
56 }
57 }
58 MNN::Compression::LayerQuantizeParams empty;
59 return empty;
60 };
61
62 EXPRP squeezeExpr = nullptr;
63
64 // BiasAdd
65 if (opType == OpType::OpType_BinaryOp) {
66 auto binaryOp = expr->get();
67 auto binaryParams = binaryOp->main_as_BinaryOp();
68 if (binaryParams->opType() != BinaryOpOperation_ADD) {
69 return noPost;
70 }
71
72 auto input0 = expr->inputs()[0];
73 auto expr0 = input0->expr().first;
74 auto input1 = expr->inputs()[1];
75 auto expr1 = input1->expr().first;
76
77 EXPRP constExpr = nullptr;
78 VARP constVar = nullptr;
79
80 if (helpers::IsConstant(expr0) && helpers::IsConstant(expr1)) {
81 return noPost;
82 }
83 if (helpers::IsConstant(expr0)) {
84 constExpr = expr0;
85 constVar = input0;
86 squeezeExpr = expr1;
87 } else if (helpers::IsConstant(expr1)) {
88 constExpr = expr1;
89 constVar = input1;
90 squeezeExpr = expr0;
91 } else {
92 return noPost;
93 }
94
95 if (constExpr->get() == nullptr) { // expr const
96 if (constVar->getInfo()->dim.size() > 1) {
97 return noPost;
98 }
99 } else { // op const
100 auto constParam = constExpr->get()->main_as_Blob();
101 if (constParam->dims()->size() > 1) {
102 return noPost;
103 }
104 }
105
106 if (!squeezeExpr->get() || squeezeExpr->get()->type() != OpType::OpType_Squeeze) {
107 return noPost;
108 }
109 auto squeezeDims = squeezeExpr->get()->main_as_SqueezeParam()->squeezeDims();
110 if (squeezeDims->size() != 1) {
111 return noPost;
112 }
113 if ((squeezeDims->data()[0] == -1) || (squeezeDims->data()[0] == 3)) {
114 return noPost;
115 }
116
117 returnPost = Conv1dPostCases::BiasAdd;
118 }
119 // relu
120 else if (opType == OpType::OpType_ReLU || opType == OpType::OpType_ReLU6) {
121 auto input = expr->inputs()[0];
122 auto inputExpr = input->expr().first;
123
124 if (!inputExpr->get() || inputExpr->get()->type() != OpType::OpType_Squeeze) {
125 return noPost;
126 }
127 squeezeExpr = inputExpr;
128
129 returnPost = Conv1dPostCases::Relu;
130 }
131 else {
132 return noPost;
133 }
134
135 if (squeezeExpr != nullptr) {
136 auto squeezeInput = squeezeExpr->inputs()[0];
137 auto squeezeInputExpr = squeezeInput->expr().first;
138 if (squeezeInputExpr->get() && squeezeInputExpr->get()->main_type() == OpParameter_Convolution2D && squeezeInputExpr->outputs().size() == 1) {
139 if (compressFileName != "") {
140 auto quantParams = findQuantParameters(proto, squeezeInputExpr->outputName(0));
141 // some conv1d squeeze may not be considered
142 if (quantParams.weight_size() != 0) {
143 return noPost;
144 }
145 }
146 }
147 }
148
149 return returnPost;
150 };
151
__anonb12fa25e0302() 152 static auto gRegister = []() {
153 auto match = [](EXPRP expr) {
154 auto postCase = getConv1dPostCase(expr);
155 if (postCase != Conv1dPostCases::None) {
156 return true;
157 }
158
159 return false;
160 };
161
162 auto transform = [](EXPRP expr) {
163 auto postCase = getConv1dPostCase(expr);
164
165 if (postCase == Conv1dPostCases::BiasAdd) {
166 auto input0 = expr->inputs()[0];
167 auto expr0 = input0->expr().first;
168 auto input1 = expr->inputs()[1];
169 auto expr1 = input1->expr().first;
170
171 EXPRP constExpr = nullptr;
172 VARP constVar = nullptr;
173 EXPRP squeezeExpr = nullptr;
174 VARP squeezeInput = nullptr;
175 int constIndex = 0;
176 std::vector<VARP> newBiasAddInputs;
177
178 if (helpers::IsConstant(expr0)) {
179 constExpr = expr0;
180 constVar = input0;
181 squeezeExpr = expr1;
182 squeezeInput = expr1->inputs()[0];
183 constIndex = 0;
184 } else if (helpers::IsConstant(expr1)) {
185 constExpr = expr1;
186 constVar = input1;
187 squeezeExpr = expr0;
188 squeezeInput = expr0->inputs()[0];
189 constIndex = 1;
190 }
191
192 auto squeezeInputExpr = squeezeInput->expr().first;
193 if (squeezeInputExpr->get() && squeezeInputExpr->get()->main_type() == OpParameter_Convolution2D && squeezeInputExpr->outputs().size() == 1) {
194 auto convInput = squeezeInputExpr->inputs()[0];
195 auto newConvExpr = Expr::create(squeezeInputExpr->extra(), {convInput});
196 newConvExpr->setName(squeezeInputExpr->name());
197 auto newConvOutput = Variable::create(newConvExpr, 0);
198 newConvOutput->setName(squeezeInputExpr->outputName(0));
199 squeezeInput = newConvOutput;
200 }
201
202 if (constIndex == 0) {
203 newBiasAddInputs.push_back(constVar);
204 newBiasAddInputs.push_back(squeezeInput);
205 } else {
206 newBiasAddInputs.push_back(squeezeInput);
207 newBiasAddInputs.push_back(constVar);
208 }
209
210 auto newBiasAddExpr = Expr::create(expr->extra(), std::move(newBiasAddInputs));
211 newBiasAddExpr->setName(expr->name());
212 auto newBiasAddVar = Variable::create(newBiasAddExpr, 0);
213 newBiasAddVar->setName(expr->outputName(0));
214 auto newSqueezeExpr = Expr::create(squeezeExpr->extra(), {newBiasAddVar});
215 newSqueezeExpr->setName(squeezeExpr->name());
216 auto newSqueezeVar = Variable::create(newSqueezeExpr, 0);
217 newSqueezeVar->setName(squeezeExpr->outputName(0));
218
219 Expr::replace(expr, newSqueezeExpr);
220 return true;
221 }
222
223 if (postCase == Conv1dPostCases::Relu) {
224 auto input = expr->inputs()[0];
225 auto squeezeExpr = input->expr().first;
226 auto squeezeInput = squeezeExpr->inputs()[0];
227 auto squeezeInputExpr = squeezeInput->expr().first;
228
229 if (squeezeInputExpr->get() && squeezeInputExpr->get()->main_type() == OpParameter_Convolution2D && squeezeInputExpr->outputs().size() == 1) {
230 auto convInput = squeezeInputExpr->inputs()[0];
231 auto newConvExpr = Expr::create(squeezeInputExpr->extra(), {convInput});
232 newConvExpr->setName(squeezeInputExpr->name());
233 auto newConvOutput = Variable::create(newConvExpr, 0);
234 newConvOutput->setName(squeezeInputExpr->outputName(0));
235 squeezeInput = newConvOutput;
236 }
237
238 auto newReluExpr = Expr::create(expr->extra(), {squeezeInput});
239 newReluExpr->setName(expr->name());
240 auto newReluVar = Variable::create(newReluExpr, 0);
241 newReluVar->setName(expr->outputName(0));
242 auto newSqueezeExpr = Expr::create(squeezeExpr->extra(), {newReluVar});
243 newSqueezeExpr->setName(squeezeExpr->name());
244 auto newSqueezeVar = Variable::create(newSqueezeExpr, 0);
245 newSqueezeVar->setName(squeezeExpr->outputName(0));
246
247 Expr::replace(expr, newSqueezeExpr);
248 return true;
249 }
250
251 return false;
252 };
253
254 TemplateMerge::getInstance("Merge").insertTemplate("Conv1dSqueezeMove", match, transform,
255 PASS_PRIORITY_HIGH);
256 return true;
257 }();
258
259 }
260 } // namespace MNN
261