1 //
2 //  StackTransform.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/11/20.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef StackTransform_hpp
10 #define StackTransform_hpp
11 
12 #include <MNN/expr/ExprCreator.hpp>
13 #include "Transform.hpp"
14 
15 namespace MNN {
16 namespace Train {
17 
18 class MNN_PUBLIC StackTransform : public BatchTransform {
transformBatch(std::vector<Example> batch)19     std::vector<Example> transformBatch(std::vector<Example> batch) override {
20         std::vector<std::vector<VARP>> batchData(batch[0].first.size());
21         std::vector<std::vector<VARP>> batchTarget(batch[0].second.size());
22         for (int i = 0; i < batch.size(); i++) {
23             for (int j = 0; j < batchData.size(); j++) {
24                 batchData[j].emplace_back(batch[i].first[j]);
25             }
26         }
27 
28         for (int i = 0; i < batch.size(); i++) {
29             for (int j = 0; j < batchTarget.size(); j++) {
30                 batchTarget[j].emplace_back(batch[i].second[j]);
31             }
32         }
33 
34         Example example;
35         for (int i = 0; i < batchData.size(); i++) {
36             example.first.emplace_back(_Stack(batchData[i], 0));
37         }
38         for (int i = 0; i < batchTarget.size(); i++) {
39             example.second.emplace_back(_Stack(batchTarget[i], 0));
40         }
41 
42         return {example};
43     }
44 };
45 
46 } // namespace Train
47 } // namespace MNN
48 
49 #endif // StackTransform_hpp
50