1 // 2 // MergeToConvolution.hpp 3 // MNNConverter 4 // 5 // Created by MNN on 2019/09/05. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include "../PostTreatUtils.hpp" 10 using namespace MNN; 11 12 class MergeToConvolution : public PostConverter { 13 public: 14 virtual bool merge2Convolution(const MNN::OpT* inplaceOp, MNN::OpT* convolutionOp) const = 0; 15 16 virtual bool merge2Convolution3D(const MNN::OpT* inplaceOp, MNN::OpT* convolutionOp) const = 0; 17 onExecute(std::unique_ptr<MNN::NetT> & net) const18 virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override { 19 // Merge Layer 20 std::vector<MNN::OpT*> readyToDelete; 21 for (auto iter = net->oplists.begin(); iter != net->oplists.end(); iter++) { 22 MNN::OpT& currentOp = *(iter->get()); 23 if (currentOp.type != MNN::OpType_Convolution 24 && currentOp.type != MNN::OpType_Deconvolution 25 && currentOp.type != MNN::OpType_ConvolutionDepthwise 26 && currentOp.type != MNN::OpType_Convolution3D) { 27 continue; 28 } 29 DCHECK(currentOp.outputIndexes.size() == 1) << "Conv output ERROR!"; 30 31 // merge Batchnorm/Relu/Relu6 to Convolution 32 std::vector<MNN::OpT*> nextOp = PostTreatUtils::_findOpByInputIndex(currentOp.outputIndexes[0], net.get()); 33 while (1) { 34 if (nextOp.size() != 1) { 35 break; 36 } 37 const int nextOutputIndex = nextOp[0]->outputIndexes[0]; 38 bool succ; 39 if (currentOp.type == MNN::OpType_Convolution3D) { 40 succ = merge2Convolution3D(nextOp[0], ¤tOp); 41 } else { 42 succ = merge2Convolution(nextOp[0], ¤tOp); 43 } 44 if (PostTreatUtils::_isSingleInputOutput(nextOp[0]) && succ) { 45 // LOG(INFO) << "Merge " << nextOp[0]->name.c_str()<< " into convolution: " << 46 // currentOp.name.c_str(); 47 currentOp.outputIndexes[0] = nextOp[0]->outputIndexes[0]; 48 readyToDelete.push_back(nextOp[0]); 49 nextOp = PostTreatUtils::_findOpByInputIndex(nextOutputIndex, net.get()); 50 } else { 51 break; 52 } 53 } 54 } 55 for (auto op : readyToDelete) { 56 PostTreatUtils::_removeOpInNet(op, net.get()); 57 } 58 return true; 59 } 60 }; 61