1 //
2 //  Conv1dSqueezeMove.cpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2021/03/05.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "../TemplateMerge.hpp"
10 #include "MNN_generated.h"
11 #include "MergeHelpers.hpp"
12 #include "cli.hpp"
13 #include "MNN_compression.pb.h"
14 #include <fstream>
15 
16 namespace MNN {
17 namespace Express {
18 
// Post ops that may appear directly after a conv1d's trailing Squeeze and
// that this pass knows how to move in front of the Squeeze.
// Scoped enum: keeps the collision-prone names (None, BiasAdd, Relu) out of
// the enclosing namespace; every use in this file is already qualified as
// Conv1dPostCases::X, so scoping is a drop-in change.
enum class Conv1dPostCases {
    None,     // no movable post op matched
    BiasAdd,  // BinaryOp ADD with a 1-D constant operand (a bias)
    Relu,     // ReLU or ReLU6
    // don't need BiasAddRelu
};
25 
__anonb12fa25e0102(EXPRP expr) 26 auto getConv1dPostCase = [](EXPRP expr) {
27     auto noPost = Conv1dPostCases::None;
28     auto returnPost = noPost;
29 
30     if (nullptr == expr->get()) {
31         return noPost;
32     }
33 
34     auto opType = expr->get()->type();
35 
36     auto gConverterConfig = Global<modelConfig>::Get();
37     std::string compressFileName = gConverterConfig->compressionParamsFile;
38     Compression::Pipeline proto;
39     if (compressFileName != "") {
40         std::fstream input(compressFileName.c_str(), std::ios::in | std::ios::binary);
41         if (!proto.ParseFromIstream(&input)) {
42             MNN_ERROR("Failed to parse compression pipeline proto.\n");
43         }
44     }
45 
46     auto findQuantParameters = [&](Compression::Pipeline& proto, std::string outputTensorName) {
47         for (const auto& algo : proto.algo()) {
48             if (algo.type() == Compression::CompressionAlgo::QUANTIZE) {
49                 auto quantParams = algo.quant_params();
50                 for (const auto& layerProto : quantParams.layer()) {
51                     const std::string& outputName = layerProto.output(0).name();
52                     if (outputName == outputTensorName) {
53                         return layerProto;
54                     }
55                 }
56             }
57         }
58         MNN::Compression::LayerQuantizeParams empty;
59         return empty;
60     };
61 
62     EXPRP squeezeExpr = nullptr;
63 
64     // BiasAdd
65     if (opType == OpType::OpType_BinaryOp) {
66         auto binaryOp     = expr->get();
67         auto binaryParams = binaryOp->main_as_BinaryOp();
68         if (binaryParams->opType() != BinaryOpOperation_ADD) {
69             return noPost;
70         }
71 
72         auto input0 = expr->inputs()[0];
73         auto expr0  = input0->expr().first;
74         auto input1 = expr->inputs()[1];
75         auto expr1  = input1->expr().first;
76 
77         EXPRP constExpr = nullptr;
78         VARP constVar = nullptr;
79 
80         if (helpers::IsConstant(expr0) && helpers::IsConstant(expr1)) {
81             return noPost;
82         }
83         if (helpers::IsConstant(expr0)) {
84             constExpr = expr0;
85             constVar = input0;
86             squeezeExpr = expr1;
87         } else if (helpers::IsConstant(expr1)) {
88             constExpr = expr1;
89             constVar = input1;
90             squeezeExpr = expr0;
91         } else {
92             return noPost;
93         }
94 
95         if (constExpr->get() == nullptr) { // expr const
96             if (constVar->getInfo()->dim.size() > 1) {
97                 return noPost;
98             }
99         } else { // op const
100             auto constParam = constExpr->get()->main_as_Blob();
101             if (constParam->dims()->size() > 1) {
102                 return noPost;
103             }
104         }
105 
106         if (!squeezeExpr->get() || squeezeExpr->get()->type() != OpType::OpType_Squeeze) {
107             return noPost;
108         }
109         auto squeezeDims = squeezeExpr->get()->main_as_SqueezeParam()->squeezeDims();
110         if (squeezeDims->size() != 1) {
111             return noPost;
112         }
113         if ((squeezeDims->data()[0] == -1) || (squeezeDims->data()[0] == 3)) {
114             return noPost;
115         }
116 
117         returnPost = Conv1dPostCases::BiasAdd;
118     }
119     // relu
120     else if (opType == OpType::OpType_ReLU || opType == OpType::OpType_ReLU6) {
121         auto input = expr->inputs()[0];
122         auto inputExpr  = input->expr().first;
123 
124         if (!inputExpr->get() || inputExpr->get()->type() != OpType::OpType_Squeeze) {
125             return noPost;
126         }
127         squeezeExpr = inputExpr;
128 
129         returnPost = Conv1dPostCases::Relu;
130     }
131     else {
132         return noPost;
133     }
134 
135     if (squeezeExpr != nullptr) {
136         auto squeezeInput = squeezeExpr->inputs()[0];
137         auto squeezeInputExpr = squeezeInput->expr().first;
138         if (squeezeInputExpr->get() && squeezeInputExpr->get()->main_type() == OpParameter_Convolution2D && squeezeInputExpr->outputs().size() == 1) {
139             if (compressFileName != "") {
140                 auto quantParams = findQuantParameters(proto, squeezeInputExpr->outputName(0));
141                 // some conv1d squeeze may not be considered
142                 if (quantParams.weight_size() != 0) {
143                     return noPost;
144                 }
145             }
146         }
147     }
148 
149     return returnPost;
150 };
151 
__anonb12fa25e0302() 152 static auto gRegister = []() {
153     auto match = [](EXPRP expr) {
154         auto postCase = getConv1dPostCase(expr);
155         if (postCase != Conv1dPostCases::None) {
156             return true;
157         }
158 
159         return false;
160     };
161 
162     auto transform = [](EXPRP expr) {
163         auto postCase = getConv1dPostCase(expr);
164 
165         if (postCase == Conv1dPostCases::BiasAdd) {
166             auto input0 = expr->inputs()[0];
167             auto expr0  = input0->expr().first;
168             auto input1 = expr->inputs()[1];
169             auto expr1  = input1->expr().first;
170 
171             EXPRP constExpr = nullptr;
172             VARP constVar = nullptr;
173             EXPRP squeezeExpr = nullptr;
174             VARP squeezeInput = nullptr;
175             int constIndex = 0;
176             std::vector<VARP> newBiasAddInputs;
177 
178             if (helpers::IsConstant(expr0)) {
179                 constExpr = expr0;
180                 constVar = input0;
181                 squeezeExpr = expr1;
182                 squeezeInput = expr1->inputs()[0];
183                 constIndex = 0;
184             } else if (helpers::IsConstant(expr1)) {
185                 constExpr = expr1;
186                 constVar = input1;
187                 squeezeExpr = expr0;
188                 squeezeInput = expr0->inputs()[0];
189                 constIndex = 1;
190             }
191 
192             auto squeezeInputExpr = squeezeInput->expr().first;
193             if (squeezeInputExpr->get() && squeezeInputExpr->get()->main_type() == OpParameter_Convolution2D && squeezeInputExpr->outputs().size() == 1) {
194                 auto convInput = squeezeInputExpr->inputs()[0];
195                 auto newConvExpr = Expr::create(squeezeInputExpr->extra(), {convInput});
196                 newConvExpr->setName(squeezeInputExpr->name());
197                 auto newConvOutput = Variable::create(newConvExpr, 0);
198                 newConvOutput->setName(squeezeInputExpr->outputName(0));
199                 squeezeInput = newConvOutput;
200             }
201 
202             if (constIndex == 0) {
203                 newBiasAddInputs.push_back(constVar);
204                 newBiasAddInputs.push_back(squeezeInput);
205             } else {
206                 newBiasAddInputs.push_back(squeezeInput);
207                 newBiasAddInputs.push_back(constVar);
208             }
209 
210             auto newBiasAddExpr = Expr::create(expr->extra(), std::move(newBiasAddInputs));
211             newBiasAddExpr->setName(expr->name());
212             auto newBiasAddVar = Variable::create(newBiasAddExpr, 0);
213             newBiasAddVar->setName(expr->outputName(0));
214             auto newSqueezeExpr = Expr::create(squeezeExpr->extra(), {newBiasAddVar});
215             newSqueezeExpr->setName(squeezeExpr->name());
216             auto newSqueezeVar = Variable::create(newSqueezeExpr, 0);
217             newSqueezeVar->setName(squeezeExpr->outputName(0));
218 
219             Expr::replace(expr, newSqueezeExpr);
220             return true;
221         }
222 
223         if (postCase == Conv1dPostCases::Relu) {
224             auto input = expr->inputs()[0];
225             auto squeezeExpr  = input->expr().first;
226             auto squeezeInput = squeezeExpr->inputs()[0];
227             auto squeezeInputExpr = squeezeInput->expr().first;
228 
229             if (squeezeInputExpr->get() && squeezeInputExpr->get()->main_type() == OpParameter_Convolution2D && squeezeInputExpr->outputs().size() == 1) {
230                 auto convInput = squeezeInputExpr->inputs()[0];
231                 auto newConvExpr = Expr::create(squeezeInputExpr->extra(), {convInput});
232                 newConvExpr->setName(squeezeInputExpr->name());
233                 auto newConvOutput = Variable::create(newConvExpr, 0);
234                 newConvOutput->setName(squeezeInputExpr->outputName(0));
235                 squeezeInput = newConvOutput;
236             }
237 
238             auto newReluExpr = Expr::create(expr->extra(), {squeezeInput});
239             newReluExpr->setName(expr->name());
240             auto newReluVar = Variable::create(newReluExpr, 0);
241             newReluVar->setName(expr->outputName(0));
242             auto newSqueezeExpr = Expr::create(squeezeExpr->extra(), {newReluVar});
243             newSqueezeExpr->setName(squeezeExpr->name());
244             auto newSqueezeVar = Variable::create(newSqueezeExpr, 0);
245             newSqueezeVar->setName(squeezeExpr->outputName(0));
246 
247             Expr::replace(expr, newSqueezeExpr);
248             return true;
249         }
250 
251         return false;
252     };
253 
254     TemplateMerge::getInstance("Merge").insertTemplate("Conv1dSqueezeMove", match, transform,
255                                                        PASS_PRIORITY_HIGH);
256     return true;
257 }();
258 
259 }
260 } // namespace MNN
261