1 //
2 // GeometrySlice.cpp
3 // MNN
4 //
5 // Created by MNN on 2020/04/07.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "geometry/GeometryComputer.hpp"
10 #include "core/OpCommonUtils.hpp"
11 namespace MNN {
12 class GeometrySliceTF : public GeometryComputer {
13 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const14 virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
15 Context& context, CommandBuffer& res) const override {
16 auto input = inputs[0];
17 // these two inputs should be const
18 auto begin_tensor = inputs[1];
19
20 auto beginPtr = begin_tensor->host<int32_t>();
21
22 std::vector<int> seperateDimIndexes;
23 std::vector<int> outputStrides(input->buffer().dimensions);
24 auto output = outputs[0];
25 int stride = 1;
26 int srcOffset = 0;
27 for (int i = input->buffer().dimensions - 1; i >= 0; --i) {
28 outputStrides[i] = stride;
29 srcOffset += beginPtr[i] * stride;
30 stride *= input->length(i);
31 }
32 for (int i = 0; i < output->buffer().dimensions; ++i) {
33 if (1 != output->length(i)) {
34 seperateDimIndexes.emplace_back(i);
35 }
36 }
37 auto outputDes = TensorUtils::getDescribe(output);
38 int basicStride = 1;
39 // Compute inside, outside, axis
40 int inside = 1;
41 int insideStride = 0;
42 int outside = 1;
43 int outsideStride = 0;
44 int axis = 1;
45 int axisStride = 0;
46 int breakAxis = 0;
47 int remainSize = 1;
48 {
49 if (seperateDimIndexes.size() >= 1) {
50 auto index = seperateDimIndexes[seperateDimIndexes.size() - 1];
51 inside = output->length(index);
52 insideStride = outputStrides[index];
53 }
54 if (seperateDimIndexes.size() >= 2) {
55 auto index = seperateDimIndexes[seperateDimIndexes.size() - 2];
56 axis = output->length(index);
57 axisStride = outputStrides[index];
58 }
59 if (seperateDimIndexes.size() >= 3) {
60 auto index = seperateDimIndexes[seperateDimIndexes.size() - 3];
61 outside = output->length(index);
62 outsideStride = outputStrides[index];
63 breakAxis = (int)seperateDimIndexes.size() - 3;
64 for (int i = 0; i < seperateDimIndexes.size() - 3; ++i) {
65 remainSize *= output->length(seperateDimIndexes[i]);
66 }
67 }
68 }
69 outputDes->regions.resize(remainSize);
70 std::vector<int32_t> mod(breakAxis);
71 for (int i = 0; i < breakAxis; ++i) {
72 int value = 1;
73 for (int j = i + 1; j < breakAxis; ++j) {
74 auto index = seperateDimIndexes[j];
75 value *= output->length(index);
76 }
77 mod[i] = value;
78 }
79 for (int indice = 0; indice < remainSize; ++indice) {
80 int value = indice;
81 int inputOffset = 0;
82 for (int i = 0; i < breakAxis; ++i) {
83 auto coordinate = value / mod[i];
84 auto index = seperateDimIndexes[i];
85 inputOffset += (coordinate)*outputStrides[index];
86 value = value % mod[i];
87 }
88 outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
89 Tensor::InsideDescribe::Region& slice = outputDes->regions[indice];
90 slice.src.offset = inputOffset + srcOffset;
91 slice.src.stride[0] = outsideStride * basicStride;
92 slice.size[0] = outside;
93 slice.src.stride[1] = axisStride * basicStride;
94 slice.size[1] = axis;
95 slice.src.stride[2] = insideStride * basicStride;
96 slice.size[2] = inside;
97 slice.origin = input;
98 slice.dst.offset = indice * outside * axis * inside;
99 slice.dst.stride[0] = axis * inside;
100 slice.dst.stride[1] = inside;
101 slice.dst.stride[2] = 1;
102 }
103 return true;
104 }
105 };
106 class GeometrySlice : public GeometryComputer {
107 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const108 virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
109 Context& context, CommandBuffer& res) const override {
110 auto input = inputs[0];
111 int axis = 0;
112 bool inputFix = false;
113 if (op->type() == OpType_Slice) {
114 auto slice = op->main_as_Slice();
115 axis = slice->axis();
116 } else if (op->type() == OpType_Unpack) {
117 axis = op->main_as_Axis()->axis();
118 inputFix = true;
119 }
120
121 if (axis < 0) {
122 axis = axis + input->dimensions();
123 }
124 int outside = 1;
125 int inside = 1;
126 for (int i = 0; i < axis; ++i) {
127 outside *= input->length(i);
128 }
129 for (int i = axis + 1; i < input->dimensions(); ++i) {
130 inside *= input->length(i);
131 }
132 auto inputZero = input->elementSize() <= 0;
133 int offset = 0;
134 for (int i = 0; i < outputs.size(); ++i) {
135 auto outputDes = TensorUtils::getDescribe(outputs[i]);
136 outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
137 if (inputZero) {
138 outputDes->regions.clear();
139 continue;
140 }
141 outputDes->regions.resize(1);
142 auto& slice = outputDes->regions[0];
143 slice.src.offset = offset * inside;
144 slice.origin = input;
145 slice.size[0] = outside;
146 slice.size[2] = inside;
147 slice.src.stride[0] = input->length(axis) * inside;
148 slice.src.stride[1] = inside;
149 slice.src.stride[2] = 1;
150 if (inputFix) {
151 slice.size[1] = 1;
152 offset += 1;
153 } else {
154 slice.size[1] = outputs[i]->length(axis);
155 offset += outputs[i]->length(axis);
156 }
157 slice.dst.offset = 0;
158 slice.dst.stride[0] = inside * slice.size[1];
159 slice.dst.stride[1] = slice.size[2];
160 slice.dst.stride[2] = 1;
161 }
162 return true;
163 }
164 };
165
_create()166 static void _create() {
167 std::shared_ptr<GeometryComputer> comp(new GeometrySlice);
168 GeometryComputer::registerGeometryComputer(comp, {OpType_Slice, OpType_Unpack});
169 std::shared_ptr<GeometryComputer> comp2(new GeometrySliceTF);
170 GeometryComputer::registerGeometryComputer(comp2, {OpType_SliceTf});
171 }
172
173 REGISTER_GEOMETRY(GeometrySlice, _create);
174
175 } // namespace MNN
176