1 //
2 // TFliteBatchToSpace.cpp
3 // MNNConverter
4 //
5 // Created by MNN on 2021/04/19.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "MNN_generated.h"
10 #include "TFliteExtraManager.hpp"
11
12 namespace MNN {
13 namespace Express {
copyInfo(SpaceBatchT * dst,std::vector<VARP> inputs)14 static void copyInfo(SpaceBatchT* dst, std::vector<VARP> inputs) {
15 MNN_ASSERT(inputs.size() == 3);
16 {
17 auto blockShape = inputs[1];
18 auto info = blockShape->getInfo();
19 auto ptr = blockShape->readMap<int>();
20 dst->blockShape.reset(new BlobT);
21 auto block = dst->blockShape.get();
22 block->dataFormat = MNN_DATA_FORMAT_NHWC;
23 block->dataType = DataType_DT_INT32;
24 if (info != nullptr) {
25 block->dims = info->dim;
26 if (ptr != nullptr) {
27 block->int32s.resize(info->size);
28 ::memcpy(block->int32s.data(), ptr, info->size * sizeof(int32_t));
29 }
30 }
31 }
32 {
33 auto padding = inputs[2];
34 auto info = padding->getInfo();
35 auto ptr = padding->readMap<int>();
36 dst->padding.reset(new BlobT);
37 auto block = dst->padding.get();
38 block->dataFormat = MNN_DATA_FORMAT_NHWC;
39 block->dataType = DataType_DT_INT32;
40 if (info != nullptr) {
41 block->dims = info->dim;
42 if (ptr != nullptr) {
43 block->int32s.resize(info->size);
44 ::memcpy(block->int32s.data(), ptr, info->size * sizeof(int32_t));
45 }
46 }
47 }
48 }
49
50 class BatchToSpaceTransform : public TFliteExtraManager::Transform {
51 public:
onExecute(EXPRP expr) const52 virtual EXPRP onExecute(EXPRP expr) const override {
53 auto op = expr->get();
54 MNN_ASSERT(op->type() == OpType_Extra);
55 auto type = op->main_as_Extra()->type()->str();
56 auto inputs = expr->inputs();
57 MNN_ASSERT(inputs.size() == 3);
58 std::unique_ptr<OpT> bsND(new OpT);
59 bsND->name = expr->name();
60 bsND->type = OpType_BatchToSpaceND;
61 bsND->main.type = OpParameter_SpaceBatch;
62 bsND->main.value = new SpaceBatchT;
63 copyInfo(bsND->main.AsSpaceBatch(), inputs);
64 auto newExpr = Expr::create(bsND.get(), inputs, expr->outputSize());
65 return newExpr;
66 }
67 };
68 class SpaceToBatchTransform : public TFliteExtraManager::Transform {
69 public:
onExecute(EXPRP expr) const70 virtual EXPRP onExecute(EXPRP expr) const override {
71 auto op = expr->get();
72 MNN_ASSERT(op->type() == OpType_Extra);
73 auto type = op->main_as_Extra()->type()->str();
74 auto inputs = expr->inputs();
75 MNN_ASSERT(inputs.size() == 3);
76 std::unique_ptr<OpT> bsND(new OpT);
77 bsND->name = expr->name();
78 bsND->type = OpType_SpaceToBatchND;
79 bsND->main.type = OpParameter_SpaceBatch;
80 bsND->main.value = new SpaceBatchT;
81 copyInfo(bsND->main.AsSpaceBatch(), inputs);
82 auto newExpr = Expr::create(bsND.get(), inputs, expr->outputSize());
83 return newExpr;
84 }
85 };
__anonc9c05a4a0102() 86 static auto gRegister = []() {
87 TFliteExtraManager::get()->insert("BatchToSpace",
88 std::shared_ptr<TFliteExtraManager::Transform>(new BatchToSpaceTransform));
89 TFliteExtraManager::get()->insert("SpaceToBatch",
90 std::shared_ptr<TFliteExtraManager::Transform>(new SpaceToBatchTransform));
91 return true;
92 }();
93 } // namespace Express
94 } // namespace MNN
95