1 // 2 // TransformIm2Seq.cpp 3 // MNNConverter 4 // 5 // Created by MNN on 2019/09/05. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include "../PostTreatUtils.hpp" 10 11 class TransformIm2Seq : public PostConverter { 12 public: onExecute(std::unique_ptr<MNN::NetT> & net) const13 virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override { 14 for (auto iter = net->oplists.begin(); iter != net->oplists.end();) { 15 auto& op = *iter; 16 if (op->type != MNN::OpType_Im2Seq) { 17 iter++; 18 continue; 19 } 20 auto inputId = op->inputIndexes[0]; 21 auto outputId = op->outputIndexes[0]; 22 auto outputname = net->tensorName[outputId]; 23 24 // New Reshape 25 MNN::OpT* reshapeT = new MNN::OpT; 26 reshapeT->name = "____reshape____" + op->name; 27 auto reshapeP = new MNN::ReshapeT; 28 reshapeP->dims.push_back(0); // b 29 reshapeP->dims.push_back(-1); // c 30 reshapeP->dims.push_back(1); // h 31 reshapeP->dims.push_back(0); // w 32 reshapeT->main.type = MNN::OpParameter_Reshape; 33 reshapeT->type = MNN::OpType_Reshape; 34 reshapeT->main.value = reshapeP; 35 36 // Net Tensor 37 net->tensorName.push_back(reshapeT->name); 38 int tempId = net->tensorName.size() - 1; 39 40 reshapeT->inputIndexes.push_back(inputId); 41 reshapeT->outputIndexes.push_back(tempId); 42 43 op->inputIndexes[0] = tempId; 44 op->type = MNN::OpType_Permute; 45 46 auto convP = new MNN::PermuteT; 47 op->main.type = MNN::OpParameter_Permute; 48 op->main.value = convP; 49 convP->dims.push_back(0); 50 convP->dims.push_back(3); 51 convP->dims.push_back(2); 52 convP->dims.push_back(1); 53 54 iter = net->oplists.insert(iter, std::unique_ptr<MNN::OpT>(reshapeT)); 55 } 56 return true; 57 } 58 }; 59 static PostConverterRegister<TransformIm2Seq> __l("TransformIm2Seq"); 60