1 //
2 //  TransformIm2Seq.cpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2019/09/05.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "../PostTreatUtils.hpp"
10 
11 class TransformIm2Seq : public PostConverter {
12 public:
onExecute(std::unique_ptr<MNN::NetT> & net) const13     virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
14         for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
15             auto& op = *iter;
16             if (op->type != MNN::OpType_Im2Seq) {
17                 iter++;
18                 continue;
19             }
20             auto inputId    = op->inputIndexes[0];
21             auto outputId   = op->outputIndexes[0];
22             auto outputname = net->tensorName[outputId];
23 
24             // New Reshape
25             MNN::OpT* reshapeT = new MNN::OpT;
26             reshapeT->name     = "____reshape____" + op->name;
27             auto reshapeP      = new MNN::ReshapeT;
28             reshapeP->dims.push_back(0);  // b
29             reshapeP->dims.push_back(-1); // c
30             reshapeP->dims.push_back(1);  // h
31             reshapeP->dims.push_back(0);  // w
32             reshapeT->main.type  = MNN::OpParameter_Reshape;
33             reshapeT->type       = MNN::OpType_Reshape;
34             reshapeT->main.value = reshapeP;
35 
36             // Net Tensor
37             net->tensorName.push_back(reshapeT->name);
38             int tempId = net->tensorName.size() - 1;
39 
40             reshapeT->inputIndexes.push_back(inputId);
41             reshapeT->outputIndexes.push_back(tempId);
42 
43             op->inputIndexes[0] = tempId;
44             op->type            = MNN::OpType_Permute;
45 
46             auto convP     = new MNN::PermuteT;
47             op->main.type  = MNN::OpParameter_Permute;
48             op->main.value = convP;
49             convP->dims.push_back(0);
50             convP->dims.push_back(3);
51             convP->dims.push_back(2);
52             convP->dims.push_back(1);
53 
54             iter = net->oplists.insert(iter, std::unique_ptr<MNN::OpT>(reshapeT));
55         }
56         return true;
57     }
58 };
59 static PostConverterRegister<TransformIm2Seq> __l("TransformIm2Seq");
60