1 //
2 // Pooling3DTf.cpp
3 // MNNConverter
4 //
5 // Created by MNN on 2019/09/29.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "TfUtils.hpp"
10 #include "graph.pb.h"
11 #include "tfOpConverter.hpp"
12
13 DECLARE_OP_CONVERTER(Pooling3DTf);
14
opType()15 MNN::OpType Pooling3DTf::opType() {
16 return MNN::OpType_Pooling3D;
17 }
type()18 MNN::OpParameter Pooling3DTf::type() {
19 return MNN::OpParameter_Pool3D;
20 }
21
22 // input: tensor
run(MNN::OpT * dstOp,TmpNode * srcNode)23 void Pooling3DTf::run(MNN::OpT *dstOp, TmpNode *srcNode) {
24 auto pool3d = new MNN::Pool3DT;
25
26 tensorflow::AttrValue value;
27
28 int stride_h = 1;
29 int stride_w = 1;
30
31 if (srcNode->opType == "AvgPool3D") {
32 pool3d->type = MNN::PoolType_AVEPOOL;
33 } else if (srcNode->opType == "MaxPool3D") {
34 pool3d->type = MNN::PoolType_MAXPOOL;
35 } else {
36 DLOG(ERROR) << "Not Support This Pooling Type: " << srcNode->opType;
37 }
38
39 if (find_attr_value(srcNode->tfNode, "ksize", value)) {
40 std::vector<int32_t> kernels;
41 for (int i = 1; i < 4; ++i) {
42 kernels.push_back(value.list().i(i));
43 }
44 pool3d->kernels = kernels;
45 }
46
47 if (find_attr_value(srcNode->tfNode, "strides", value)) {
48 std::vector<int32_t> strides;
49 for (int i = 1; i < 4; ++i) {
50 strides.push_back(value.list().i(i));
51 }
52 pool3d->strides = strides;
53 }
54
55 if (find_attr_value(srcNode->tfNode, "padding", value)) {
56 if (value.s() == "VALID") {
57 pool3d->padType = MNN::PoolPadType_VALID;
58 pool3d->pads = std::vector<int32_t>(3, 0);
59 } else if (value.s() == "SAME") {
60 pool3d->padType = MNN::PoolPadType_SAME;
61 } else {
62 DLOG(ERROR) << "Not Support This Padding Mode";
63 }
64 }
65
66 dstOp->main.value = pool3d;
67 }
68
69 REGISTER_CONVERTER(Pooling3DTf, MaxPool3D);
70 REGISTER_CONVERTER(Pooling3DTf, AvgPool3D);
71