1 //
2 //  TFliteBatchToSpace.cpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2021/04/19.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "MNN_generated.h"
10 #include "TFliteExtraManager.hpp"
11 
12 namespace MNN {
13 namespace Express {
copyInfo(SpaceBatchT * dst,std::vector<VARP> inputs)14 static void copyInfo(SpaceBatchT* dst, std::vector<VARP> inputs) {
15     MNN_ASSERT(inputs.size() == 3);
16     {
17         auto blockShape = inputs[1];
18         auto info = blockShape->getInfo();
19         auto ptr = blockShape->readMap<int>();
20         dst->blockShape.reset(new BlobT);
21         auto block        = dst->blockShape.get();
22         block->dataFormat = MNN_DATA_FORMAT_NHWC;
23         block->dataType = DataType_DT_INT32;
24         if (info != nullptr) {
25             block->dims = info->dim;
26             if (ptr != nullptr) {
27                 block->int32s.resize(info->size);
28                 ::memcpy(block->int32s.data(), ptr, info->size * sizeof(int32_t));
29             }
30         }
31     }
32     {
33         auto padding = inputs[2];
34         auto info = padding->getInfo();
35         auto ptr = padding->readMap<int>();
36         dst->padding.reset(new BlobT);
37         auto block        = dst->padding.get();
38         block->dataFormat = MNN_DATA_FORMAT_NHWC;
39         block->dataType = DataType_DT_INT32;
40         if (info != nullptr) {
41             block->dims = info->dim;
42             if (ptr != nullptr) {
43                 block->int32s.resize(info->size);
44                 ::memcpy(block->int32s.data(), ptr, info->size * sizeof(int32_t));
45             }
46         }
47     }
48 }
49 
50 class BatchToSpaceTransform : public TFliteExtraManager::Transform {
51 public:
onExecute(EXPRP expr) const52     virtual EXPRP onExecute(EXPRP expr) const override {
53         auto op = expr->get();
54         MNN_ASSERT(op->type() == OpType_Extra);
55         auto type   = op->main_as_Extra()->type()->str();
56         auto inputs = expr->inputs();
57         MNN_ASSERT(inputs.size() == 3);
58         std::unique_ptr<OpT> bsND(new OpT);
59         bsND->name       = expr->name();
60         bsND->type       = OpType_BatchToSpaceND;
61         bsND->main.type  = OpParameter_SpaceBatch;
62         bsND->main.value = new SpaceBatchT;
63         copyInfo(bsND->main.AsSpaceBatch(), inputs);
64         auto newExpr = Expr::create(bsND.get(), inputs, expr->outputSize());
65         return newExpr;
66     }
67 };
68 class SpaceToBatchTransform : public TFliteExtraManager::Transform {
69 public:
onExecute(EXPRP expr) const70     virtual EXPRP onExecute(EXPRP expr) const override {
71         auto op = expr->get();
72         MNN_ASSERT(op->type() == OpType_Extra);
73         auto type   = op->main_as_Extra()->type()->str();
74         auto inputs = expr->inputs();
75         MNN_ASSERT(inputs.size() == 3);
76         std::unique_ptr<OpT> bsND(new OpT);
77         bsND->name       = expr->name();
78         bsND->type       = OpType_SpaceToBatchND;
79         bsND->main.type  = OpParameter_SpaceBatch;
80         bsND->main.value = new SpaceBatchT;
81         copyInfo(bsND->main.AsSpaceBatch(), inputs);
82         auto newExpr = Expr::create(bsND.get(), inputs, expr->outputSize());
83         return newExpr;
84     }
85 };
__anonc9c05a4a0102() 86 static auto gRegister = []() {
87     TFliteExtraManager::get()->insert("BatchToSpace",
88                                       std::shared_ptr<TFliteExtraManager::Transform>(new BatchToSpaceTransform));
89     TFliteExtraManager::get()->insert("SpaceToBatch",
90                                       std::shared_ptr<TFliteExtraManager::Transform>(new SpaceToBatchTransform));
91     return true;
92 }();
93 } // namespace Express
94 } // namespace MNN
95