1 //
2 //  ShapeStridedSlice.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/10.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include <algorithm>
10 #include <array>
11 #include "shape/SizeComputer.hpp"
12 #include "core/Macro.h"
13 #include "core/TensorUtils.hpp"
14 namespace MNN {
15 class StridedSliceComputer : public SizeComputer {
16 public:
onComputeSize(const MNN::Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs) const17     virtual bool onComputeSize(const MNN::Op *op, const std::vector<Tensor *> &inputs,
18                                const std::vector<Tensor *> &outputs) const override {
19         MNN_ASSERT(4 == inputs.size());
20         MNN_ASSERT(1 == outputs.size());
21 
22         Tensor *input            = inputs[0];
23         const int inputDim = input->buffer().dimensions;
24         if (inputDim <= 0 || inputDim > MNN_MAX_TENSOR_DIM) {
25             return false;
26         }
27         auto parameter = op->main_as_StridedSliceParam();
28         int32_t beginMask = parameter->beginMask();
29         int32_t endMask = parameter->endMask();
30         int32_t shrinkAxisMask = parameter->shrinkAxisMask();
31         int32_t ellipsisMask = parameter->ellipsisMask();
32         int32_t newAxisMask = parameter->newAxisMask();
33         if (ellipsisMask && (ellipsisMask & (ellipsisMask - 1))) {
34             MNN_ERROR("only one non-zero bit is allowed in ellipsisMask\n");
35             return false;
36         }
37 
38         Tensor *begin   = inputs[1];
39         Tensor *end     = inputs[2];
40         Tensor *strided = inputs[3];
41         auto output    = outputs[0];
42 
43         MNN_ASSERT(begin->buffer().dimensions == end->buffer().dimensions &&
44                    begin->buffer().dimensions == strided->buffer().dimensions);
45 
46         int32_t inputShape[MNN_MAX_TENSOR_DIM] = { 0 };
47         int32_t begins[MNN_MAX_TENSOR_DIM] = { 0 };
48         int32_t ends[MNN_MAX_TENSOR_DIM] = { 0 };
49         int32_t strides[MNN_MAX_TENSOR_DIM] = { 0 };
50         int32_t beginMasks[MNN_MAX_TENSOR_DIM] = { 0 };
51         int32_t endMasks[MNN_MAX_TENSOR_DIM] = { 0 };
52         int32_t shrinkAxisMasks[MNN_MAX_TENSOR_DIM] = { 0 };
53         int32_t newAxisMasks[MNN_MAX_TENSOR_DIM] = { 0 };
54         int strideSize = begin->length(0);
55         for (int i = 0; i < inputDim; i++) {
56             inputShape[i] = input->length(i);
57         }
58         for (int i = 0; i < strideSize; i++) {
59             beginMasks[i] = beginMask & (1 << i);
60         }
61         for (int i = 0; i < strideSize; i++) {
62             endMasks[i] = endMask & (1 << i);
63         }
64         for (int i = 0; i < strideSize; i++) {
65             shrinkAxisMasks[i] = shrinkAxisMask & (1 << i);
66         }
67         for (int i = 0; i < strideSize; i++) {
68             newAxisMasks[i] = newAxisMask & (1 << i);
69         }
70 
71         // deal ellipsis, expand strides info
72         if (ellipsisMask > 0) {
73             int32_t beginMasksTmp[MNN_MAX_TENSOR_DIM] = { 0 };
74             int32_t endMasksTmp[MNN_MAX_TENSOR_DIM] = { 0 };
75             int32_t shrinkAxisMasksTmp[MNN_MAX_TENSOR_DIM] = { 0 };
76             int32_t newAxisMasksTmp[MNN_MAX_TENSOR_DIM] = { 0 };
77             // expand stride info
78             int ellipsisPos = -1;
79             for (int i = 0; i < strideSize; i++) {
80                 int temp = ellipsisMask & (1 << i);
81                 if (temp != 0) {
82                     ellipsisPos = i;
83                     break;
84                 }
85             }
86             MNN_ASSERT(ellipsisPos >= 0 && ellipsisPos < strideSize);
87             /*
88              Example: foo's dim is [2, 3, 4, 5, 6, 7], foo[0:2, :, 3:5, 3:6]:
89                 1. strideSize = 4, inputDim = 6, ellipsis = 2(0010)
90                 2. left part: 0:2, right part: 3:5, 3:6
91                 3. expand: foo[0:2, 0:3, 0:4, 3:5, 3:6]
92              */
93             int ellpsisSize = inputDim - strideSize, strideIdx = 0;
94             for (int i = 0; i < inputDim; i++) {
95                 if (i == ellipsisPos) {
96                     strideIdx++;
97                 }
98                 if (i >= ellipsisPos && i <= ellipsisPos + ellpsisSize) {
99                     begins[i] = 0;
100                     ends[i] = inputShape[i];
101                     strides[i] = 1;
102                     beginMasksTmp[i] = 0;
103                     endMasksTmp[i] = 0;
104                     shrinkAxisMasksTmp[i] = 0;
105                 } else {
106                     begins[i] = begin->host<int32_t>()[strideIdx];
107                     ends[i] = end->host<int32_t>()[strideIdx];
108                     strides[i] = strided->host<int32_t>()[strideIdx];
109                     beginMasksTmp[i] = beginMasks[strideIdx];
110                     endMasksTmp[i] = endMasks[strideIdx];
111                     shrinkAxisMasksTmp[i] = shrinkAxisMasks[strideIdx];
112                     newAxisMasksTmp[i] = newAxisMasks[strideIdx++];
113                 }
114             }
115             for (int i = 0; i < inputDim; i++) {
116                 beginMasks[i] = beginMasksTmp[i];
117                 endMasks[i] = endMasksTmp[i];
118                 shrinkAxisMasks[i] = shrinkAxisMasksTmp[i];
119                 newAxisMasks[i] = newAxisMasksTmp[i];
120             }
121             strideSize = inputDim;
122         } else {
123             for (int i = 0; i < strideSize; i++) {
124                 begins[i] = begin->host<int>()[i];
125                 ends[i] = end->host<int>()[i];
126                 strides[i] = strided->host<int>()[i];
127             }
128         }
129 
130         int32_t beginShape[MNN_MAX_TENSOR_DIM];
131         int32_t endShape[MNN_MAX_TENSOR_DIM];
132         int32_t stridedShape[MNN_MAX_TENSOR_DIM];
133         int32_t outputShape[MNN_MAX_TENSOR_DIM];
134         int32_t outputShapeShrinked[MNN_MAX_TENSOR_DIM];
135 
136         int outputShapeSize = 0;
137         int outputShapeShrinkSize = 0;
138         int strideDealDims = 0;
139 
140         auto beginAndEndShapeLimit = [](int shape, int dimSize, bool exclusive) -> int {
141             int maxShape = dimSize - 1, minShape = -dimSize;
142             if (exclusive) {
143                 ++maxShape;
144                 --minShape;
145             }
146             shape = (shape > maxShape ? maxShape : shape);
147             shape = (shape < minShape ? minShape : shape);
148             if (shape < 0) {
149                 shape += dimSize;
150             }
151             return shape;
152         };
153 
154         int inputDimOffset = 0;
155         for (int i = 0; i < strideSize; i++) {
156             if (newAxisMasks[i] > 0) {
157                 outputShape[outputShapeSize] = 1;
158                 outputShapeSize++;
159                 outputShapeShrinked[outputShapeShrinkSize] = 1;
160                 outputShapeShrinkSize++;
161                 continue;
162             }
163             auto inputDim = inputShape[inputDimOffset++];
164             strideDealDims++;
165             if (beginMasks[i] > 0) {
166                 beginShape[i] = 0;
167             } else {
168                 beginShape[i] = std::min(inputDim, begins[i]);
169             }
170             if (beginShape[i] < 0) {
171                 beginShape[i] += input->buffer().dim[i].extent;
172             }
173             if (endMasks[i] > 0) {
174                 endShape[i] = inputDim;
175             } else {
176                 endShape[i] = beginAndEndShapeLimit(ends[i], inputDim, true);
177             }
178             stridedShape[i] = shrinkAxisMasks[i] > 0 ? 1 : strides[i];
179 
180             if (endShape[i] < beginShape[i]) {
181                 int t         = beginShape[i];
182                 beginShape[i] = endShape[i];
183                 endShape[i]   = t;
184 
185                 MNN_ASSERT(stridedShape[i] != 0);
186                 if (stridedShape[i] < 0) {
187                     stridedShape[i] = -stridedShape[i];
188                 } else {
189                     // MNN_ASSERT(false);  // TODO: should be the wrong case, but there is one in linfeng's faster
190                     // rcnn face model
191                     beginShape[i] = endShape[i]; // TODO: temp solution
192                 }
193             }
194 
195             if (shrinkAxisMasks[i] == 0) {
196                 int size = (endShape[i] - beginShape[i] - 1) / stridedShape[i] + 1;
197                 outputShape[outputShapeSize] = size;
198                 outputShapeSize++;
199                 outputShapeShrinked[outputShapeShrinkSize] = size;
200                 outputShapeShrinkSize++;
201             } else {
202                 outputShape[outputShapeSize] = std::min(1, inputDim);
203                 outputShapeSize++;
204             }
205         }
206 
207         int outputDimensionsWithoutRemain = strideDealDims;
208         int dimensionRemained             = input->buffer().dimensions - strideDealDims;
209 
210         for (int i = 0; i < dimensionRemained; i++) {
211             outputShapeShrinked[outputShapeShrinkSize] = input->buffer().dim[outputDimensionsWithoutRemain + i].extent;
212             outputShapeShrinkSize++;
213         }
214 
215         output->buffer().dimensions    = outputShapeShrinkSize;
216         output->buffer().type          = input->buffer().type;
217 
218         for (int i = 0; i < outputShapeShrinkSize; i++) {
219             output->buffer().dim[i].extent = outputShapeShrinked[i];
220         }
221         TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat;
222         return true;
223     }
224 };
225 
// Registration hook (macro defined elsewhere) tying StridedSliceComputer to OpType_StridedSlice.
// NOTE(review): {1,2,3} presumably lists the inputs whose *contents* shape inference depends on
// (begin, end, strides -- onComputeSize reads their host data); confirm against the macro definition.
REGISTER_SHAPE_INPUTS(StridedSliceComputer, OpType_StridedSlice, (std::vector<int>{1,2,3}));
227 } // namespace MNN
228