1 //
2 //  StridedSliceTf.cpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2019/01/31.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include <string.h>
10 #include "TfUtils.hpp"
11 #include "tfOpConverter.hpp"
12 
13 #include "graph.pb.h"
14 
15 DECLARE_OP_CONVERTER(StridedSliceTf);
16 
opType()17 MNN::OpType StridedSliceTf::opType() {
18     return MNN::OpType_StridedSlice;
19 }
type()20 MNN::OpParameter StridedSliceTf::type() {
21     return MNN::OpParameter_StridedSliceParam;
22 }
23 
run(MNN::OpT * dstOp,TmpNode * srcNode)24 void StridedSliceTf::run(MNN::OpT *dstOp, TmpNode *srcNode) {
25     auto stridedslice = new MNN::StridedSliceParamT;
26 
27     tensorflow::AttrValue value;
28     find_attr_value(srcNode->tfNode, "begin_mask", value);
29     stridedslice->beginMask = value.i();
30 
31     find_attr_value(srcNode->tfNode, "end_mask", value);
32     stridedslice->endMask = value.i();
33 
34     find_attr_value(srcNode->tfNode, "ellipsis_mask", value);
35     stridedslice->ellipsisMask = value.i();
36 
37     find_attr_value(srcNode->tfNode, "new_axis_mask", value);
38     stridedslice->newAxisMask = value.i();
39 
40     find_attr_value(srcNode->tfNode, "shrink_axis_mask", value);
41     stridedslice->shrinkAxisMask = value.i();
42 
43     find_attr_value(srcNode->tfNode, "Index", value);
44     stridedslice->Index = (MNN::DataType)value.type();
45 
46     find_attr_value(srcNode->tfNode, "T", value);
47     stridedslice->T = (MNN::DataType)value.type();
48 
49     dstOp->main.value = stridedslice;
50 }
51 
52 REGISTER_CONVERTER(StridedSliceTf, StridedSlice);
53