1 //
2 //  GeometrySlice.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/04/07.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "geometry/GeometryComputer.hpp"
10 #include "core/OpCommonUtils.hpp"
11 namespace MNN {
12 class GeometrySliceTF : public GeometryComputer {
13 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const14     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
15                            Context& context, CommandBuffer& res) const override {
16         auto input = inputs[0];
17         // these two inputs should be const
18         auto begin_tensor = inputs[1];
19 
20         auto beginPtr = begin_tensor->host<int32_t>();
21 
22         std::vector<int> seperateDimIndexes;
23         std::vector<int> outputStrides(input->buffer().dimensions);
24         auto output   = outputs[0];
25         int stride    = 1;
26         int srcOffset = 0;
27         for (int i = input->buffer().dimensions - 1; i >= 0; --i) {
28             outputStrides[i] = stride;
29             srcOffset += beginPtr[i] * stride;
30             stride *= input->length(i);
31         }
32         for (int i = 0; i < output->buffer().dimensions; ++i) {
33             if (1 != output->length(i)) {
34                 seperateDimIndexes.emplace_back(i);
35             }
36         }
37         auto outputDes  = TensorUtils::getDescribe(output);
38         int basicStride = 1;
39         // Compute inside, outside, axis
40         int inside        = 1;
41         int insideStride  = 0;
42         int outside       = 1;
43         int outsideStride = 0;
44         int axis          = 1;
45         int axisStride    = 0;
46         int breakAxis     = 0;
47         int remainSize    = 1;
48         {
49             if (seperateDimIndexes.size() >= 1) {
50                 auto index   = seperateDimIndexes[seperateDimIndexes.size() - 1];
51                 inside       = output->length(index);
52                 insideStride = outputStrides[index];
53             }
54             if (seperateDimIndexes.size() >= 2) {
55                 auto index = seperateDimIndexes[seperateDimIndexes.size() - 2];
56                 axis       = output->length(index);
57                 axisStride = outputStrides[index];
58             }
59             if (seperateDimIndexes.size() >= 3) {
60                 auto index    = seperateDimIndexes[seperateDimIndexes.size() - 3];
61                 outside       = output->length(index);
62                 outsideStride = outputStrides[index];
63                 breakAxis     = (int)seperateDimIndexes.size() - 3;
64                 for (int i = 0; i < seperateDimIndexes.size() - 3; ++i) {
65                     remainSize *= output->length(seperateDimIndexes[i]);
66                 }
67             }
68         }
69         outputDes->regions.resize(remainSize);
70         std::vector<int32_t> mod(breakAxis);
71         for (int i = 0; i < breakAxis; ++i) {
72             int value = 1;
73             for (int j = i + 1; j < breakAxis; ++j) {
74                 auto index = seperateDimIndexes[j];
75                 value *= output->length(index);
76             }
77             mod[i] = value;
78         }
79         for (int indice = 0; indice < remainSize; ++indice) {
80             int value       = indice;
81             int inputOffset = 0;
82             for (int i = 0; i < breakAxis; ++i) {
83                 auto coordinate = value / mod[i];
84                 auto index      = seperateDimIndexes[i];
85                 inputOffset += (coordinate)*outputStrides[index];
86                 value = value % mod[i];
87             }
88             outputDes->memoryType                 = Tensor::InsideDescribe::MEMORY_VIRTUAL;
89             Tensor::InsideDescribe::Region& slice = outputDes->regions[indice];
90             slice.src.offset                      = inputOffset + srcOffset;
91             slice.src.stride[0]                   = outsideStride * basicStride;
92             slice.size[0]                         = outside;
93             slice.src.stride[1]                   = axisStride * basicStride;
94             slice.size[1]                         = axis;
95             slice.src.stride[2]                   = insideStride * basicStride;
96             slice.size[2]                         = inside;
97             slice.origin                          = input;
98             slice.dst.offset                      = indice * outside * axis * inside;
99             slice.dst.stride[0]                   = axis * inside;
100             slice.dst.stride[1]                   = inside;
101             slice.dst.stride[2]                   = 1;
102         }
103         return true;
104     }
105 };
106 class GeometrySlice : public GeometryComputer {
107 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const108     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
109                            Context& context, CommandBuffer& res) const override {
110         auto input    = inputs[0];
111         int axis      = 0;
112         bool inputFix = false;
113         if (op->type() == OpType_Slice) {
114             auto slice = op->main_as_Slice();
115             axis       = slice->axis();
116         } else if (op->type() == OpType_Unpack) {
117             axis     = op->main_as_Axis()->axis();
118             inputFix = true;
119         }
120 
121         if (axis < 0) {
122             axis = axis + input->dimensions();
123         }
124         int outside = 1;
125         int inside  = 1;
126         for (int i = 0; i < axis; ++i) {
127             outside *= input->length(i);
128         }
129         for (int i = axis + 1; i < input->dimensions(); ++i) {
130             inside *= input->length(i);
131         }
132         auto inputZero = input->elementSize() <= 0;
133         int offset = 0;
134         for (int i = 0; i < outputs.size(); ++i) {
135             auto outputDes = TensorUtils::getDescribe(outputs[i]);
136             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
137             if (inputZero) {
138                 outputDes->regions.clear();
139                 continue;
140             }
141             outputDes->regions.resize(1);
142             auto& slice           = outputDes->regions[0];
143             slice.src.offset      = offset * inside;
144             slice.origin          = input;
145             slice.size[0]         = outside;
146             slice.size[2]         = inside;
147             slice.src.stride[0]   = input->length(axis) * inside;
148             slice.src.stride[1]   = inside;
149             slice.src.stride[2]   = 1;
150             if (inputFix) {
151                 slice.size[1] = 1;
152                 offset += 1;
153             } else {
154                 slice.size[1] = outputs[i]->length(axis);
155                 offset += outputs[i]->length(axis);
156             }
157             slice.dst.offset = 0;
158             slice.dst.stride[0] = inside * slice.size[1];
159             slice.dst.stride[1] = slice.size[2];
160             slice.dst.stride[2] = 1;
161         }
162         return true;
163     }
164 };
165 
_create()166 static void _create() {
167     std::shared_ptr<GeometryComputer> comp(new GeometrySlice);
168     GeometryComputer::registerGeometryComputer(comp, {OpType_Slice, OpType_Unpack});
169     std::shared_ptr<GeometryComputer> comp2(new GeometrySliceTF);
170     GeometryComputer::registerGeometryComputer(comp2, {OpType_SliceTf});
171 }
172 
173 REGISTER_GEOMETRY(GeometrySlice, _create);
174 
175 } // namespace MNN
176