1 //
2 //  NPUStridedSlice.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/09/07.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "NPUStridedSlice.hpp"
10 #include "NPUBackend.hpp"
11 
12 using namespace std;
13 
14 namespace MNN {
15 
NPUStridedSlice(Backend * b,const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)16 NPUStridedSlice::NPUStridedSlice(Backend *b, const Op *op, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) : MNN::NPUCommonExecution(b,op)
17 {
18     bool isConst1 = TensorUtils::getDescribe(inputs[1])->usage==Tensor::InsideDescribe::Usage::CONSTANT;
19     bool isConst2 = TensorUtils::getDescribe(inputs[2])->usage==Tensor::InsideDescribe::Usage::CONSTANT;
20     bool isConst3 = TensorUtils::getDescribe(inputs[3])->usage==Tensor::InsideDescribe::Usage::CONSTANT;
21     auto opName = mOp->name()->str();
22     Tensor *begin   = inputs[1];
23     Tensor *end     = inputs[2];
24     Tensor *strided = inputs[3];
25 
26     if(isConst1 == true) {
27         auto beginShape = convertShapeConstValue(begin, 0);
28         mConst_b = ge::op::Const(opName + "_b_const");
29         {
30             ge::TensorDesc fdesc(ge::Shape({4}), ge::DT_INT32);
31             ge::TensorPtr filter = std::make_shared<ge::Tensor>();
32             filter->SetTensorDesc(fdesc);
33             filter->SetData((uint8_t *)&beginShape[0], 4*sizeof(int32_t));
34             mConst_b.set_attr_value(filter);
35         }
36     }
37 
38     if(isConst2 == true) {
39         auto endShape = convertShapeConstValue(end, 0);
40         mConst_e = ge::op::Const(opName + "_e_const");
41         {
42             ge::TensorDesc fdesc(ge::Shape({4}),  ge::DT_INT32);
43             ge::TensorPtr filter = std::make_shared<ge::Tensor>();
44             filter->SetTensorDesc(fdesc);
45             filter->SetData((uint8_t *)&endShape[0], 4*sizeof(int32_t));
46             mConst_e.set_attr_value(filter);
47         }
48     }
49 
50     if(isConst3 == true) {
51         auto stridedShape = convertShapeConstValue(strided);
52         mConst_s = ge::op::Const(opName + "_s_const");
53         {
54             ge::TensorDesc fdesc(ge::Shape({4}), ge::DT_INT32);
55             ge::TensorPtr filter = std::make_shared<ge::Tensor>();
56             filter->SetTensorDesc(fdesc);
57             filter->SetData((uint8_t *)&stridedShape[0], 4*sizeof(int32_t));
58             mConst_s.set_attr_value(filter);
59         }
60     }
61     auto parameter = mOp->main_as_StridedSliceParam();
62     beginMask = convertMask(begin, parameter->beginMask(),1);
63     endMask = convertMask(begin, parameter->endMask(),1);
64     ellipsisMask = parameter->ellipsisMask(); //框架未使用
65     newAxisMask = parameter->newAxisMask();
66     shrinkAxisMask = convertMask(begin, parameter->shrinkAxisMask());
67 }
68 
onResize(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)69 ErrorCode NPUStridedSlice::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
70     mNpuBackend->setNetworkInput(inputs, mOp);
71 
72     auto opName = mOp->name()->str();
73     auto param  = mOp->main_as_Axis();
74 
75     shared_ptr<ge::op::StridedSlice> stride_slice(new ge::op::StridedSlice(opName));
76 
77     auto inputIndex = mOp->inputIndexes()->data()[0];
78     auto iops       = mNpuBackend->mGrapMap[inputIndex]; // x
79     auto xOp        = iops.back().first;
80 
81     auto parameter = mOp->main_as_StridedSliceParam();
82 
83     (*stride_slice)
84         .set_input_x(*xOp.get())
85         .set_input_begin(mConst_b)
86         .set_input_end(mConst_e)
87         .set_input_strides(mConst_s)
88         .set_attr_begin_mask(beginMask)
89         .set_attr_end_mask(endMask)
90         .set_attr_ellipsis_mask(ellipsisMask)
91         .set_attr_new_axis_mask(newAxisMask)
92         .set_attr_shrink_axis_mask(shrinkAxisMask);
93 
94     mNpuBackend->setOutputOps(mOp, {stride_slice}, outputs);
95 
96     return NO_ERROR;
97 }
98 
// Namespace-scope registrar object constructed with OpType_StridedSlice; presumably it
// registers TypedCreator<NPUStridedSlice> as the NPU-backend creator for that op type
// during static initialization — confirm in NPUBackend.hpp.
// NOTE(review): identifiers containing a double underscore are reserved for the
// implementation in C++; kept as-is here because the name is a global symbol.
NPUCreatorRegister<TypedCreator<NPUStridedSlice>> __stride_slice_op(OpType_StridedSlice);
100 
101 } // namespace MNN
102