1 //
2 //  MergeToConvolution.hpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2019/09/05.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "../PostTreatUtils.hpp"
10 using namespace MNN;
11 
12 class MergeToConvolution : public PostConverter {
13 public:
14     virtual bool merge2Convolution(const MNN::OpT* inplaceOp, MNN::OpT* convolutionOp) const = 0;
15 
16     virtual bool merge2Convolution3D(const MNN::OpT* inplaceOp, MNN::OpT* convolutionOp) const = 0;
17 
onExecute(std::unique_ptr<MNN::NetT> & net) const18     virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
19         // Merge Layer
20         std::vector<MNN::OpT*> readyToDelete;
21         for (auto iter = net->oplists.begin(); iter != net->oplists.end(); iter++) {
22             MNN::OpT& currentOp = *(iter->get());
23             if (currentOp.type != MNN::OpType_Convolution
24                 && currentOp.type != MNN::OpType_Deconvolution
25                 && currentOp.type != MNN::OpType_ConvolutionDepthwise
26                 && currentOp.type != MNN::OpType_Convolution3D) {
27                 continue;
28             }
29             DCHECK(currentOp.outputIndexes.size() == 1) << "Conv output ERROR!";
30 
31             // merge Batchnorm/Relu/Relu6 to Convolution
32             std::vector<MNN::OpT*> nextOp = PostTreatUtils::_findOpByInputIndex(currentOp.outputIndexes[0], net.get());
33             while (1) {
34                 if (nextOp.size() != 1) {
35                     break;
36                 }
37                 const int nextOutputIndex = nextOp[0]->outputIndexes[0];
38                 bool succ;
39                 if (currentOp.type == MNN::OpType_Convolution3D) {
40                     succ = merge2Convolution3D(nextOp[0], &currentOp);
41                 } else {
42                     succ = merge2Convolution(nextOp[0], &currentOp);
43                 }
44                 if (PostTreatUtils::_isSingleInputOutput(nextOp[0]) && succ) {
45                     // LOG(INFO) << "Merge " << nextOp[0]->name.c_str()<< " into convolution: " <<
46                     // currentOp.name.c_str();
47                     currentOp.outputIndexes[0] = nextOp[0]->outputIndexes[0];
48                     readyToDelete.push_back(nextOp[0]);
49                     nextOp = PostTreatUtils::_findOpByInputIndex(nextOutputIndex, net.get());
50                 } else {
51                     break;
52                 }
53             }
54         }
55         for (auto op : readyToDelete) {
56             PostTreatUtils::_removeOpInNet(op, net.get());
57         }
58         return true;
59     }
60 };
61