1 //
2 // StridedSliceTf.cpp
3 // MNNConverter
4 //
5 // Created by MNN on 2019/01/31.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include <string.h>
10 #include "TfUtils.hpp"
11 #include "tfOpConverter.hpp"
12
13 #include "graph.pb.h"
14
// Declares the StridedSliceTf converter whose opType()/type()/run() members are
// defined below (macro comes from tfOpConverter.hpp — expansion not visible here).
DECLARE_OP_CONVERTER(StridedSliceTf);
16
opType()17 MNN::OpType StridedSliceTf::opType() {
18 return MNN::OpType_StridedSlice;
19 }
type()20 MNN::OpParameter StridedSliceTf::type() {
21 return MNN::OpParameter_StridedSliceParam;
22 }
23
run(MNN::OpT * dstOp,TmpNode * srcNode)24 void StridedSliceTf::run(MNN::OpT *dstOp, TmpNode *srcNode) {
25 auto stridedslice = new MNN::StridedSliceParamT;
26
27 tensorflow::AttrValue value;
28 find_attr_value(srcNode->tfNode, "begin_mask", value);
29 stridedslice->beginMask = value.i();
30
31 find_attr_value(srcNode->tfNode, "end_mask", value);
32 stridedslice->endMask = value.i();
33
34 find_attr_value(srcNode->tfNode, "ellipsis_mask", value);
35 stridedslice->ellipsisMask = value.i();
36
37 find_attr_value(srcNode->tfNode, "new_axis_mask", value);
38 stridedslice->newAxisMask = value.i();
39
40 find_attr_value(srcNode->tfNode, "shrink_axis_mask", value);
41 stridedslice->shrinkAxisMask = value.i();
42
43 find_attr_value(srcNode->tfNode, "Index", value);
44 stridedslice->Index = (MNN::DataType)value.type();
45
46 find_attr_value(srcNode->tfNode, "T", value);
47 stridedslice->T = (MNN::DataType)value.type();
48
49 dstOp->main.value = stridedslice;
50 }
51
// Presumably registers StridedSliceTf as the handler for TensorFlow op type
// "StridedSlice" (macro from tfOpConverter.hpp — confirm expansion there).
REGISTER_CONVERTER(StridedSliceTf, StridedSlice);
53