1 //
2 // NPUStridedSlice.cpp
3 // MNN
4 //
5 // Created by MNN on 2019/09/07.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "NPUStridedSlice.hpp"
10 #include "NPUBackend.hpp"
11
12 using namespace std;
13
14 namespace MNN {
15
NPUStridedSlice(Backend * b,const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)16 NPUStridedSlice::NPUStridedSlice(Backend *b, const Op *op, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) : MNN::NPUCommonExecution(b,op)
17 {
18 bool isConst1 = TensorUtils::getDescribe(inputs[1])->usage==Tensor::InsideDescribe::Usage::CONSTANT;
19 bool isConst2 = TensorUtils::getDescribe(inputs[2])->usage==Tensor::InsideDescribe::Usage::CONSTANT;
20 bool isConst3 = TensorUtils::getDescribe(inputs[3])->usage==Tensor::InsideDescribe::Usage::CONSTANT;
21 auto opName = mOp->name()->str();
22 Tensor *begin = inputs[1];
23 Tensor *end = inputs[2];
24 Tensor *strided = inputs[3];
25
26 if(isConst1 == true) {
27 auto beginShape = convertShapeConstValue(begin, 0);
28 mConst_b = ge::op::Const(opName + "_b_const");
29 {
30 ge::TensorDesc fdesc(ge::Shape({4}), ge::DT_INT32);
31 ge::TensorPtr filter = std::make_shared<ge::Tensor>();
32 filter->SetTensorDesc(fdesc);
33 filter->SetData((uint8_t *)&beginShape[0], 4*sizeof(int32_t));
34 mConst_b.set_attr_value(filter);
35 }
36 }
37
38 if(isConst2 == true) {
39 auto endShape = convertShapeConstValue(end, 0);
40 mConst_e = ge::op::Const(opName + "_e_const");
41 {
42 ge::TensorDesc fdesc(ge::Shape({4}), ge::DT_INT32);
43 ge::TensorPtr filter = std::make_shared<ge::Tensor>();
44 filter->SetTensorDesc(fdesc);
45 filter->SetData((uint8_t *)&endShape[0], 4*sizeof(int32_t));
46 mConst_e.set_attr_value(filter);
47 }
48 }
49
50 if(isConst3 == true) {
51 auto stridedShape = convertShapeConstValue(strided);
52 mConst_s = ge::op::Const(opName + "_s_const");
53 {
54 ge::TensorDesc fdesc(ge::Shape({4}), ge::DT_INT32);
55 ge::TensorPtr filter = std::make_shared<ge::Tensor>();
56 filter->SetTensorDesc(fdesc);
57 filter->SetData((uint8_t *)&stridedShape[0], 4*sizeof(int32_t));
58 mConst_s.set_attr_value(filter);
59 }
60 }
61 auto parameter = mOp->main_as_StridedSliceParam();
62 beginMask = convertMask(begin, parameter->beginMask(),1);
63 endMask = convertMask(begin, parameter->endMask(),1);
64 ellipsisMask = parameter->ellipsisMask(); //框架未使用
65 newAxisMask = parameter->newAxisMask();
66 shrinkAxisMask = convertMask(begin, parameter->shrinkAxisMask());
67 }
68
onResize(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)69 ErrorCode NPUStridedSlice::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
70 mNpuBackend->setNetworkInput(inputs, mOp);
71
72 auto opName = mOp->name()->str();
73 auto param = mOp->main_as_Axis();
74
75 shared_ptr<ge::op::StridedSlice> stride_slice(new ge::op::StridedSlice(opName));
76
77 auto inputIndex = mOp->inputIndexes()->data()[0];
78 auto iops = mNpuBackend->mGrapMap[inputIndex]; // x
79 auto xOp = iops.back().first;
80
81 auto parameter = mOp->main_as_StridedSliceParam();
82
83 (*stride_slice)
84 .set_input_x(*xOp.get())
85 .set_input_begin(mConst_b)
86 .set_input_end(mConst_e)
87 .set_input_strides(mConst_s)
88 .set_attr_begin_mask(beginMask)
89 .set_attr_end_mask(endMask)
90 .set_attr_ellipsis_mask(ellipsisMask)
91 .set_attr_new_axis_mask(newAxisMask)
92 .set_attr_shrink_axis_mask(shrinkAxisMask);
93
94 mNpuBackend->setOutputOps(mOp, {stride_slice}, outputs);
95
96 return NO_ERROR;
97 }
98
99 NPUCreatorRegister<TypedCreator<NPUStridedSlice>> __stride_slice_op(OpType_StridedSlice);
100
101 } // namespace MNN
102