1 // 2 // StackTransform.hpp 3 // MNN 4 // 5 // Created by MNN on 2019/11/20. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef StackTransform_hpp 10 #define StackTransform_hpp 11 12 #include <MNN/expr/ExprCreator.hpp> 13 #include "Transform.hpp" 14 15 namespace MNN { 16 namespace Train { 17 18 class MNN_PUBLIC StackTransform : public BatchTransform { transformBatch(std::vector<Example> batch)19 std::vector<Example> transformBatch(std::vector<Example> batch) override { 20 std::vector<std::vector<VARP>> batchData(batch[0].first.size()); 21 std::vector<std::vector<VARP>> batchTarget(batch[0].second.size()); 22 for (int i = 0; i < batch.size(); i++) { 23 for (int j = 0; j < batchData.size(); j++) { 24 batchData[j].emplace_back(batch[i].first[j]); 25 } 26 } 27 28 for (int i = 0; i < batch.size(); i++) { 29 for (int j = 0; j < batchTarget.size(); j++) { 30 batchTarget[j].emplace_back(batch[i].second[j]); 31 } 32 } 33 34 Example example; 35 for (int i = 0; i < batchData.size(); i++) { 36 example.first.emplace_back(_Stack(batchData[i], 0)); 37 } 38 for (int i = 0; i < batchTarget.size(); i++) { 39 example.second.emplace_back(_Stack(batchTarget[i], 0)); 40 } 41 42 return {example}; 43 } 44 }; 45 46 } // namespace Train 47 } // namespace MNN 48 49 #endif // StackTransform_hpp 50