1 //
2 //  Pooling3DTf.cpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2019/09/29.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "TfUtils.hpp"
10 #include "graph.pb.h"
11 #include "tfOpConverter.hpp"
12 
// Declares the Pooling3DTf converter type whose opType(), type() and run()
// members are defined below in this file.
// NOTE(review): the macro is presumably provided by tfOpConverter.hpp — confirm.
13 DECLARE_OP_CONVERTER(Pooling3DTf);
14 
opType()15 MNN::OpType Pooling3DTf::opType() {
16     return MNN::OpType_Pooling3D;
17 }
type()18 MNN::OpParameter Pooling3DTf::type() {
19     return MNN::OpParameter_Pool3D;
20 }
21 
22 // input: tensor
run(MNN::OpT * dstOp,TmpNode * srcNode)23 void Pooling3DTf::run(MNN::OpT *dstOp, TmpNode *srcNode) {
24     auto pool3d = new MNN::Pool3DT;
25 
26     tensorflow::AttrValue value;
27 
28     int stride_h      = 1;
29     int stride_w      = 1;
30 
31     if (srcNode->opType == "AvgPool3D") {
32         pool3d->type = MNN::PoolType_AVEPOOL;
33     } else if (srcNode->opType == "MaxPool3D") {
34         pool3d->type = MNN::PoolType_MAXPOOL;
35     } else {
36         DLOG(ERROR) << "Not Support This Pooling Type: " << srcNode->opType;
37     }
38 
39     if (find_attr_value(srcNode->tfNode, "ksize", value)) {
40         std::vector<int32_t> kernels;
41         for (int i = 1; i < 4; ++i) {
42             kernels.push_back(value.list().i(i));
43         }
44         pool3d->kernels = kernels;
45     }
46 
47     if (find_attr_value(srcNode->tfNode, "strides", value)) {
48         std::vector<int32_t> strides;
49         for (int i = 1; i < 4; ++i) {
50             strides.push_back(value.list().i(i));
51         }
52         pool3d->strides = strides;
53     }
54 
55     if (find_attr_value(srcNode->tfNode, "padding", value)) {
56         if (value.s() == "VALID") {
57             pool3d->padType = MNN::PoolPadType_VALID;
58             pool3d->pads = std::vector<int32_t>(3, 0);
59         } else if (value.s() == "SAME") {
60             pool3d->padType = MNN::PoolPadType_SAME;
61         } else {
62             DLOG(ERROR) << "Not Support This Padding Mode";
63         }
64     }
65 
66     dstOp->main.value = pool3d;
67 }
68 
// Associates Pooling3DTf with the op names MaxPool3D and AvgPool3D; run()
// tells the two apart through srcNode->opType.
// NOTE(review): the registration mechanism lives in the macro, presumably
// defined in tfOpConverter.hpp — confirm.
69 REGISTER_CONVERTER(Pooling3DTf, MaxPool3D);
70 REGISTER_CONVERTER(Pooling3DTf, AvgPool3D);
71